Skip to content

Commit

Permalink
MemoryStore: mongomock -> pymongo-inmemory
Browse files Browse the repository at this point in the history
  • Loading branch information
rkingsbury committed Aug 25, 2023
1 parent f90c211 commit 148523f
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 80 deletions.
5 changes: 5 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# TODO - this entire file can be removed once pymongo-inmemory supports pyproject.toml
# see https://github.com/kaizendorks/pymongo_inmemory/issues/81
[pymongo_inmemory]
use_local_mongod = False
mongod_port = 27019
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"pydantic<2.0",
"pydantic>=0.32.2",
"pymongo>=4.2.0",
"pymongo-inmemory",
"monty>=1.0.2",
"mongomock>=3.10.0",
"pydash>=4.1.0",
Expand Down
120 changes: 42 additions & 78 deletions src/maggma/stores/mongolike.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import warnings
from itertools import chain, groupby
from itertools import chain
from pathlib import Path
from socket import socket

Expand All @@ -18,15 +18,15 @@

from typing_extensions import Literal

import mongomock
import orjson
from monty.dev import requires
from monty.io import zopen
from monty.json import MSONable, jsanitize
from monty.serialization import loadfn
from pydash import get, has, set_
from pydash import has, set_
from pymongo import MongoClient, ReplaceOne, uri_parser
from pymongo.errors import ConfigurationError, DocumentTooLarge, OperationFailure
from pymongo_inmemory import MongoClient as MemoryClient
from sshtunnel import SSHTunnelForwarder

from maggma.core import Sort, Store, StoreError
Expand Down Expand Up @@ -139,10 +139,12 @@ def __init__(
port: TCP port to connect to
username: Username for the collection
password: Password to connect with
ssh_tunnel: SSHTunnel instance to use for connection.
safe_update: fail gracefully on DocumentTooLarge errors on update
auth_source: The database to authenticate on. Defaults to the database name.
default_sort: Default sort field and direction to use when querying. Can be used to
ensure a deterministic ordering of query results.
mongoclient_kwargs: Dict of extra kwargs to pass to MongoClient()
"""
self.database = database
self.collection_name = collection_name
Expand Down Expand Up @@ -578,95 +580,57 @@ class MemoryStore(MongoStore):
to a MongoStore
"""

def __init__(self, collection_name: str = "memory_db", **kwargs):
def __init__(
self,
database: str = "mem",
collection_name: str = "memory_store",
host: str = "localhost",
port: int = 27019, # non-default port to avoid conflicts with a mongod already running on localhost (default 27017)
safe_update: bool = False,
mongoclient_kwargs: Optional[Dict] = None,
default_sort: Optional[Dict[str, Union[Sort, int]]] = None,
**kwargs,
):
"""
Initializes the Memory Store
Args:
collection_name: name for the collection in memory
"""
self.collection_name = collection_name
self.default_sort = None
self._coll = None
self.kwargs = kwargs
super(MongoStore, self).__init__(**kwargs)
database: The database name
collection_name: The collection name
host: Hostname for the database
port: TCP port to connect to
safe_update: fail gracefully on DocumentTooLarge errors on update
default_sort: Default sort field and direction to use when querying.
Can be used to ensure a deterministic ordering of query results.
mongoclient_kwargs: Dict of extra kwargs to pass to MongoClient()
"""
super().__init__(
database=database,
collection_name=collection_name,
host=host,
port=port,
safe_update=safe_update,
mongoclient_kwargs=mongoclient_kwargs,
default_sort=default_sort,
**kwargs,
)

def connect(self, force_reset: bool = False):
"""
Connect to the source data
"""
conn: MemoryClient = MemoryClient(
host=self.host,
port=self.port,
**self.mongoclient_kwargs,
)

if self._coll is None or force_reset:
self._coll = mongomock.MongoClient().db[self.name] # type: ignore

def close(self):
"""Close up all collections"""
self._coll.database.client.close()
db = conn[self.database]
self._coll = db[self.collection_name] # type: ignore

@property
def name(self):
"""Name for the store"""
return f"mem://{self.collection_name}"

def __hash__(self):
"""Hash for the store"""
return hash((self.name, self.last_updated_field))

def groupby(
self,
keys: Union[List[str], str],
criteria: Optional[Dict] = None,
properties: Union[Dict, List, None] = None,
sort: Optional[Dict[str, Union[Sort, int]]] = None,
skip: int = 0,
limit: int = 0,
) -> Iterator[Tuple[Dict, List[Dict]]]:
"""
Simple grouping function that will group documents
by keys.
Args:
keys: fields to group documents
criteria: PyMongo filter for documents to search in
properties: properties to return in grouped documents
sort: Dictionary of sort order for fields. Keys are field names and
values are 1 for ascending or -1 for descending.
skip: number documents to skip
limit: limit on total number of documents returned
Returns:
generator returning tuples of (key, list of elements)
"""
keys = keys if isinstance(keys, list) else [keys]

if properties is None:
properties = []
if isinstance(properties, dict):
properties = list(properties.keys())

data = [
doc for doc in self.query(properties=keys + properties, criteria=criteria) if all(has(doc, k) for k in keys)
]

def grouping_keys(doc):
return tuple(get(doc, k) for k in keys)

for vals, group in groupby(sorted(data, key=grouping_keys), key=grouping_keys):
doc = {} # type: ignore
for k, v in zip(keys, vals):
set_(doc, k, v)
yield doc, list(group)

def __eq__(self, other: object) -> bool:
"""
Check equality for MemoryStore
other: other MemoryStore to compare with
"""
if not isinstance(other, MemoryStore):
return False

fields = ["collection_name", "last_updated_field"]
return all(getattr(self, f) == getattr(other, f) for f in fields)


class JSONStore(MemoryStore):
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/stores/test_mongolike.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from unittest import mock

import mongomock.collection
import orjson
import pymongo.collection
import pytest
Expand Down Expand Up @@ -238,8 +237,9 @@ def test_mongostore_newer_in(mongostore):
def test_memory_store_connect():
memorystore = MemoryStore()
assert memorystore._coll is None
assert "mem:" in memorystore.name
memorystore.connect()
assert isinstance(memorystore._collection, mongomock.collection.Collection)
assert isinstance(memorystore._collection, pymongo.collection.Collection)


def test_groupby(memorystore):
Expand Down

0 comments on commit 148523f

Please sign in to comment.