diff --git a/safemap/safemap.go b/safemap/safemap.go index 32ba3a3..e3b1465 100644 --- a/safemap/safemap.go +++ b/safemap/safemap.go @@ -3,45 +3,43 @@ package safemap import "sync" type SafeMap[K comparable, V any] struct { - mu sync.RWMutex - data map[K]V + data sync.Map } func New[K comparable, V any]() *SafeMap[K, V] { return &SafeMap[K, V]{ - data: make(map[K]V), + data: sync.Map{}, } } func (s *SafeMap[K, V]) Set(k K, v V) { - s.mu.Lock() - defer s.mu.Unlock() - s.data[k] = v + s.data.Store(k, v) } func (s *SafeMap[K, V]) Get(k K) (V, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - val, ok := s.data[k] - return val, ok + val, ok := s.data.Load(k) + if !ok { + return *new(V), false // Return zero value of type V and false + } + return val.(V), ok } func (s *SafeMap[K, V]) Delete(k K) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.data, k) + s.data.Delete(k) } func (s *SafeMap[K, V]) Len() int { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.data) + count := 0 + s.data.Range(func(_, _ interface{}) bool { + count++ + return true + }) + return count } func (s *SafeMap[K, V]) ForEach(f func(K, V)) { - s.mu.RLock() - defer s.mu.RUnlock() - for key, val := range s.data { - f(key, val) - } + s.data.Range(func(key, value interface{}) bool { + f(key.(K), value.(V)) + return true + }) } diff --git a/safemap/safemap_test.go b/safemap/safemap_test.go new file mode 100644 index 0000000..e9da878 --- /dev/null +++ b/safemap/safemap_test.go @@ -0,0 +1,121 @@ +package safemap + +import ( + "math/rand" + "testing" +) + +func TestNewAndSet(t *testing.T) { + sm := New[int, string]() + sm.Set(1, "one") + sm.Set(2, "two") + + if sm.Len() != 2 { + t.Errorf("Expected length 2, got %d", sm.Len()) + } +} + +func TestGet(t *testing.T) { + sm := New[int, string]() + sm.Set(1, "one") + + val, ok := sm.Get(1) + if !ok || val != "one" { + t.Errorf("Expected 'one', got %s", val) + } + _, ok = sm.Get(2) + if ok { + t.Errorf("Expected false, got true") + } +} + +func TestDelete(t *testing.T) { + sm := New[int, string]() + sm.Set(1, "one") + sm.Delete(1) + _, ok := sm.Get(1) + if ok { + t.Errorf("Expected key 1 to be deleted") + } + + sm.Delete(2) // just make sure this doesn't panic +} + +func TestLen(t *testing.T) { + sm := New[int, string]() + sm.Set(1, "one") + sm.Set(2, "two") + sm.Set(3, "three") + sm.Delete(2) + + if sm.Len() != 2 { + t.Errorf("Expected length 2, got %d", sm.Len()) + } +} + +func TestForEach(t *testing.T) { + sm := New[int, string]() + sm.Set(1, "one") + sm.Set(2, "two") + + keys := make([]int, 0) + sm.ForEach(func(k int, v string) { + keys = append(keys, k) + }) + + if len(keys) != 2 { + t.Errorf("Expected 2 keys, got %d", len(keys)) + } + + // Check if keys 1 and 2 are present + if !contains(keys, 1) || !contains(keys, 2) { + t.Errorf("Expected keys 1 and 2, got %v", keys) + } +} + +// Helper function to check if a slice contains a specific element. +func contains(slice []int, element int) bool { + for _, a := range slice { + if a == element { + return true + } + } + return false +} + +// Benchmark Get +func BenchmarkGetConcurrent(b *testing.B) { + ds := New[uint64, uint64]() + // Pre-fill the data store with test data if needed + for i := 0; i < 100000; i++ { + ds.Set(uint64(i), uint64(i)) + } + b.SetParallelism(100) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + r := rand.Uint64() % 100000 + _, _ = ds.Get(r) + } + }) +} + +// fool the compiler +var silly uint64 + +// Benchmark Set +func BenchmarkSetConcurrent(b *testing.B) { + ds := New[uint64, uint64]() + b.SetParallelism(100) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + r := rand.Uint64() % 100000 + silly, _ = ds.Get(r) + ds.Set(r, r) + silly, _ = ds.Get(r) + + } + }) +}