Skip to content
This repository has been archived by the owner on Apr 4, 2024. It is now read-only.

Update ArrayMap #50

Merged
merged 5 commits into from
Apr 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 88 additions & 80 deletions python/selfie-lib/selfie_lib/ArrayMap.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,27 @@
from abc import ABC, abstractmethod
from collections.abc import Set, Iterator, Mapping
from typing import List, TypeVar, Union, Any, Tuple
import bisect


class Comparable:
def __lt__(self, other: Any) -> bool:
return NotImplemented

def __le__(self, other: Any) -> bool:
return NotImplemented

def __gt__(self, other: Any) -> bool:
return NotImplemented

def __ge__(self, other: Any) -> bool:
return NotImplemented

from typing import List, TypeVar, Union, Any
from abc import abstractmethod, ABC
from functools import total_ordering

T = TypeVar("T")
V = TypeVar("V")
K = TypeVar("K", bound=Comparable)
K = TypeVar("K")


def string_slash_first_comparator(a: Any, b: Any) -> int:
"""Special comparator for strings where '/' is considered the lowest."""
if isinstance(a, str) and isinstance(b, str):
return (a.replace("/", "\0"), a) < (b.replace("/", "\0"), b)
return (a < b) - (a > b)
@total_ordering
class Comparable:
def __init__(self, value):
self.value = value

def __lt__(self, other: Any) -> bool:
if not isinstance(other, Comparable):
return NotImplemented
return self.value < other.value

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Comparable):
return NotImplemented
return self.value == other.value
nedtwigg marked this conversation as resolved.
Show resolved Hide resolved


class ListBackedSet(Set[T], ABC):
Expand All @@ -37,15 +31,25 @@ def __len__(self) -> int: ...
@abstractmethod
def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ...

def __contains__(self, item: object) -> bool:
try:
index = self.__binary_search(item)
except ValueError:
return False
return index >= 0

@abstractmethod
def __binary_search(self, item: Any) -> int: ...
def __contains__(self, item: Any) -> bool:
return self._binary_search(item) >= 0

def _binary_search(self, item: Any) -> int:
low = 0
high = len(self) - 1
while low <= high:
mid = (low + high) // 2
try:
mid_val = self[mid]
if mid_val < item:
low = mid + 1
elif mid_val > item:
high = mid - 1
else:
return mid # item found
except TypeError:
raise ValueError(f"Cannot compare items due to a type mismatch.")
return -(low + 1) # item not found


class ArraySet(ListBackedSet[K]):
Expand All @@ -60,6 +64,9 @@ def __create(cls, data: List[K]) -> "ArraySet[K]":
instance.__data = data
return instance

def __iter__(self) -> Iterator[K]:
return iter(self.__data)
nedtwigg marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def empty(cls) -> "ArraySet[K]":
if not hasattr(cls, "__EMPTY"):
Expand All @@ -72,74 +79,75 @@ def __len__(self) -> int:
def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]:
return self.__data[index]

def __binary_search(self, item: K) -> int:
if isinstance(item, str):
key = lambda x: x.replace("/", "\0")
return (
bisect.bisect_left(self.__data, item, key=key) - 1
if item in self.__data
else -1
)
return bisect.bisect_left(self.__data, item) - 1 if item in self.__data else -1

def plusOrThis(self, element: K) -> "ArraySet[K]":
index = self.__binary_search(element)
if index >= 0:
if element in self:
return self
new_data = self.__data[:]
bisect.insort_left(new_data, element)
return ArraySet.__create(new_data)
else:
new_data = self.__data[:]
new_data.append(element)
new_data.sort(key=Comparable)
return ArraySet.__create(new_data)


class ArrayMap(Mapping[K, V]):
__data: List[Tuple[K, V]]

def __init__(self):
raise NotImplementedError("Use ArrayMap.empty() or other class methods instead")

@classmethod
def __create(cls, data: List[Tuple[K, V]]) -> "ArrayMap[K, V]":
instance = super().__new__(cls)
instance.__data = data
return instance
def __init__(self, data=None):
if data is None:
self.__data = []
else:
self.__data = data
nedtwigg marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def empty(cls) -> "ArrayMap[K, V]":
if not hasattr(cls, "__EMPTY"):
cls.__EMPTY = cls.__create([])
cls.__EMPTY = cls([])
return cls.__EMPTY

def __getitem__(self, key: K) -> V:
index = self.__binary_search_key(key)
index = self._binary_search_key(key)
if index >= 0:
return self.__data[index][1]
return self.__data[2 * index + 1]
raise KeyError(key)

def __iter__(self) -> Iterator[K]:
return (key for key, _ in self.__data)
return (self.__data[i] for i in range(0, len(self.__data), 2))

def __len__(self) -> int:
return len(self.__data)

def __binary_search_key(self, key: K) -> int:
keys = [k for k, _ in self.__data]
index = bisect.bisect_left(keys, key)
if index < len(keys) and keys[index] == key:
return index
return -1
return len(self.__data) // 2

def _binary_search_key(self, key: K) -> int:
def compare(a, b):
"""Comparator that puts '/' first in strings."""
if isinstance(a, str) and isinstance(b, str):
a, b = a.replace("/", "\0"), b.replace("/", "\0")
return (a > b) - (a < b)

low, high = 0, len(self.__data) // 2 - 1
while low <= high:
mid = (low + high) // 2
mid_key = self.__data[2 * mid]
comparison = compare(mid_key, key)
if comparison < 0:
low = mid + 1
elif comparison > 0:
high = mid - 1
else:
return mid # key found
return -(low + 1) # key not found

nedtwigg marked this conversation as resolved.
Show resolved Hide resolved
def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
index = self.__binary_search_key(key)
index = self._binary_search_key(key)
if index >= 0:
raise ValueError("Key already exists")
insert_at = -(index + 1)
new_data = self.__data[:]
bisect.insort_left(new_data, (key, value))
return ArrayMap.__create(new_data)
new_data.insert(insert_at * 2, key)
new_data.insert(insert_at * 2 + 1, value)
return ArrayMap(new_data)

def minus_sorted_indices(self, indicesToRemove: List[int]) -> "ArrayMap[K, V]":
if not indicesToRemove:
return self
new_data = [
item for i, item in enumerate(self.__data) if i not in indicesToRemove
]
return ArrayMap.__create(new_data)
def minus_sorted_indices(self, indices: List[int]) -> "ArrayMap[K, V]":
new_data = self.__data[:]
adjusted_indices = [i * 2 for i in indices] + [i * 2 + 1 for i in indices]
adjusted_indices.sort()
for index in reversed(adjusted_indices):
del new_data[index]
return ArrayMap(new_data)