diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 423b0a2fc..a627117ed 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -51,7 +51,9 @@ jobs: max-parallel: 6 matrix: os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10", "3.11"] + # python-version: ["3.8", "3.9", "3.10", "3.11"] + # temporary workaround since pymongo-inmemory dropped 3.8 support + python-version: ["3.9", "3.10", "3.11"] runs-on: ${{ matrix.os }} @@ -72,6 +74,8 @@ jobs: env: CONTINUOUS_INTEGRATION: True MONGODB_SRV_URI: ${{ secrets.MONGODB_SRV_URI }} + PYMONGOIM__OPERATING_SYSTEM: ubuntu + PYMONGOIM__MONGO_VERSION: 6.0 run: | pip install -e . pytest --cov=maggma --cov-report=xml diff --git a/pyproject.toml b/pyproject.toml index 30f5692ca..6be0ca5e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "pymongo>=4.2.0", "monty>=2023.9.25", "mongomock>=3.10.0", + "pymongo-inmemory>=0.4.1", "pydash>=4.1.0", "jsonschema>=3.1.1", "tqdm>=4.19.6", diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..a75e9baed --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +# TODO - this entire file can be removed once pymongo-inmemory supports pyproject.toml +# see https://github.com/kaizendorks/pymongo_inmemory/issues/81 +[pymongo_inmemory] +mongod_port = 27019 +mongo_version = 6.0 diff --git a/src/maggma/stores/file_store.py b/src/maggma/stores/file_store.py index fa1a3296a..72e14ddcf 100644 --- a/src/maggma/stores/file_store.py +++ b/src/maggma/stores/file_store.py @@ -102,7 +102,6 @@ def __init__( self.json_name = json_name file_filters = file_filters if file_filters else ["*"] self.file_filters = re.compile("|".join(fnmatch.translate(p) for p in file_filters)) - self.collection_name = "file_store" self.key = "file_id" self.include_orphans = include_orphans self.read_only = read_only @@ -112,14 +111,12 @@ def __init__( self.metadata_store = JSONStore( paths=[str(self.path / self.json_name)], read_only=self.read_only, - collection_name=self.collection_name, key=self.key, ) self.kwargs = kwargs super().__init__( - collection_name=self.collection_name, key=self.key, **self.kwargs, ) diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index 6d431aa1a..25a4f6b9f 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -5,6 +5,7 @@ """ import warnings +from datetime import datetime from itertools import chain, groupby from pathlib import Path @@ -20,7 +21,6 @@ from typing_extensions import Literal import bson -import mongomock import orjson from monty.dev import requires from monty.io import zopen @@ -29,6 +29,7 @@ from pydash import get, has, set_ from pymongo import MongoClient, ReplaceOne, uri_parser from pymongo.errors import ConfigurationError, DocumentTooLarge, OperationFailure +from pymongo_inmemory import MongoClient as MemoryClient from maggma.core import Sort, Store, StoreError from maggma.utils import confirm_field_index, to_dt @@ -67,10 +68,12 @@ def __init__( port: TCP port to connect to username: Username for the collection password: Password to connect with + ssh_tunnel: SSHTunnel instance to use for connection. safe_update: fail gracefully on DocumentTooLarge errors on update auth_source: The database to authenticate on. Defaults to the database name. default_sort: Default sort field and direction to use when querying. Can be used to ensure determinacy in query results. 
+ mongoclient_kwargs: Dict of extra kwargs to pass to MongoClient() """ self.database = database self.collection_name = collection_name @@ -437,7 +440,7 @@ def __eq__(self, other: object) -> bool: Check equality for MongoStore other: other mongostore to compare with. """ - if not isinstance(other, MongoStore): + if not isinstance(other, self.__class__): return False fields = ["database", "collection_name", "host", "port", "last_updated_field"] @@ -516,98 +519,72 @@ class MemoryStore(MongoStore): to a MongoStore. """ - def __init__(self, collection_name: str = "memory_db", **kwargs): + def __init__( + self, + database: str = "mem", + collection_name: Optional[str] = None, + host: str = "localhost", + port: int = 27019, # non-default port, to avoid conflicting with a local MongoDB on 27017 + safe_update: bool = False, + mongoclient_kwargs: Optional[Dict] = None, + default_sort: Optional[Dict[str, Union[Sort, int]]] = None, + **kwargs, + ): """ - Initializes the Memory Store. Args: - collection_name: name for the collection in memory. - """ - self.collection_name = collection_name - self.default_sort = None - self._coll = None - self.kwargs = kwargs - super(MongoStore, self).__init__(**kwargs) + database: The database name + collection_name: The collection name. If None (default), a unique collection name is generated + from the current date and time so that multiple Store instances can coexist without + overwriting one another. + host: Hostname for the database + port: TCP port to connect to + safe_update: fail gracefully on DocumentTooLarge errors on update + default_sort: Default sort field and direction to use when querying. + Can be used to ensure determinacy in query results. + mongoclient_kwargs: Dict of extra kwargs to pass to MongoClient() + """ + if not collection_name: + collection_name = str(datetime.utcnow()) + super().__init__( + database=database, + collection_name=collection_name, + host=host, + port=port, + safe_update=safe_update, + mongoclient_kwargs=mongoclient_kwargs, + default_sort=default_sort, + **kwargs, + ) def connect(self, force_reset: bool = False): """ Connect to the source data. Args: force_reset: whether to reset the connection or not when the Store is already connected. """ if self._coll is None or force_reset: - self._coll = mongomock.MongoClient().db[self.name] # type: ignore - - def close(self): - """Close up all collections.""" - self._coll.database.client.close() + conn: MemoryClient = MemoryClient( + host=self.host, + port=self.port, + **self.mongoclient_kwargs, + ) + db = conn[self.database] + self._coll = db[self.collection_name] # type: ignore @property def name(self): """Name for the store.""" - return f"mem://{self.collection_name}" + return f"mem://{self.database}/{self.collection_name}" - def __hash__(self): - """Hash for the store.""" - return hash((self.name, self.last_updated_field)) - - def groupby( - self, - keys: Union[List[str], str], - criteria: Optional[Dict] = None, - properties: Union[Dict, List, None] = None, - sort: Optional[Dict[str, Union[Sort, int]]] = None, - skip: int = 0, - limit: int = 0, - ) -> Iterator[Tuple[Dict, List[Dict]]]: - """ - Simple grouping function that will group documents - by keys. - - Args: - keys: fields to group documents - criteria: PyMongo filter for documents to search in - properties: properties to return in grouped documents - sort: Dictionary of sort order for fields. Keys are field names and - values are 1 for ascending or -1 for descending.
- skip: number documents to skip - limit: limit on total number of documents returned - - Returns: - generator returning tuples of (key, list of elements) - """ - keys = keys if isinstance(keys, list) else [keys] - - if properties is None: - properties = [] - if isinstance(properties, dict): - properties = list(properties.keys()) - - data = [ - doc for doc in self.query(properties=keys + properties, criteria=criteria) if all(has(doc, k) for k in keys) - ] - - def grouping_keys(doc): - return tuple(get(doc, k) for k in keys) - - for vals, group in groupby(sorted(data, key=grouping_keys), key=grouping_keys): - doc = {} # type: ignore - for k, v in zip(keys, vals): - set_(doc, k, v) - yield doc, list(group) - - def __eq__(self, other: object) -> bool: - """ - Check equality for MemoryStore - other: other MemoryStore to compare with. - """ - if not isinstance(other, MemoryStore): - return False - - fields = ["collection_name", "last_updated_field"] - return all(getattr(self, f) == getattr(other, f) for f in fields) + # def __del__(self): + # """ + # Ensure collection is dropped from memory on object destruction, even if .close() has not been called. + # """ + # if self._coll is not None: + # self._collection.drop() class JSONStore(MemoryStore): @@ -689,7 +661,7 @@ def connect(self, force_reset: bool = False): on systems with slow storage when multiple connect / disconnects are performed. """ if self._coll is None or force_reset: - self._coll = mongomock.MongoClient().db[self.name] # type: ignore + self._coll = MemoryClient().db[self.name] # type: ignore # create the .json file if it does not exist if not self.read_only and not Path(self.paths[0]).exists(): @@ -791,7 +764,7 @@ def __eq__(self, other: object) -> bool: Args: other: other JSONStore to compare with """ - if not isinstance(other, JSONStore): + if not isinstance(other, self.__class__): return False fields = ["paths", "last_updated_field"] @@ -803,7 +776,7 @@ "MontyStore requires MontyDB to be installed. See the MontyDB repository for more " "information: https://github.com/davidlatwe/montydb", ) -class MontyStore(MemoryStore): +class MontyStore(MongoStore): """ A MongoDB compatible store that uses on disk files for storage. @@ -944,3 +917,59 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No search_doc = {k: d[k] for k in key} if isinstance(key, list) else {key: d[key]} self._collection.replace_one(search_doc, d, upsert=True) + + def groupby( + self, + keys: Union[List[str], str], + criteria: Optional[Dict] = None, + properties: Union[Dict, List, None] = None, + sort: Optional[Dict[str, Union[Sort, int]]] = None, + skip: int = 0, + limit: int = 0, + ) -> Iterator[Tuple[Dict, List[Dict]]]: + """ + Simple grouping function that will group documents + by keys. + + Args: + keys: fields to group documents + criteria: PyMongo filter for documents to search in + properties: properties to return in grouped documents + sort: Dictionary of sort order for fields. Keys are field names and + values are 1 for ascending or -1 for descending.
+ skip: number of documents to skip + limit: limit on total number of documents returned + + Returns: + generator returning tuples of (key, list of elements) + """ + keys = keys if isinstance(keys, list) else [keys] + + if properties is None: + properties = [] + if isinstance(properties, dict): + properties = list(properties.keys()) + + data = [ + doc for doc in self.query(properties=keys + properties, criteria=criteria) if all(has(doc, k) for k in keys) + ] + + def grouping_keys(doc): + return tuple(get(doc, k) for k in keys) + + for vals, group in groupby(sorted(data, key=grouping_keys), key=grouping_keys): + doc = {} # type: ignore + for k, v in zip(keys, vals): + set_(doc, k, v) + yield doc, list(group) + + def __eq__(self, other: object) -> bool: + """ + Check equality for MontyStore + other: other Store to compare with + """ + if not isinstance(other, self.__class__): + return False + + fields = ["database_name", "collection_name", "last_updated_field"] + return all(getattr(self, f) == getattr(other, f) for f in fields) diff --git a/tests/api/test_submission_resource.py b/tests/api/test_submission_resource.py index 8ef7acd45..60cd68475 100644 --- a/tests/api/test_submission_resource.py +++ b/tests/api/test_submission_resource.py @@ -1,4 +1,3 @@ -import json from datetime import datetime from random import randint @@ -113,10 +112,10 @@ def test_submission_patch(owner_store, post_query_op, patch_query_op): app.include_router(endpoint.router) client = TestClient(app) - update = json.dumps({"last_updated": "2023-06-22T17:32:11.645713"}) + update = {"last_updated": "2023-06-22T17:32:11.645713"} assert client.get("/").status_code == 200 - assert client.patch(f"/?name=PersonAge9&update={update}").status_code == 200 + assert client.patch("/?name=PersonAge9", json=update).status_code == 200 def test_key_fields(owner_store, post_query_op): diff --git a/tests/builders/test_copy_builder.py b/tests/builders/test_copy_builder.py index d8c0bfde1..46ce64b03 100644 --- a/tests/builders/test_copy_builder.py +++ b/tests/builders/test_copy_builder.py @@ -102,7 +102,8 @@ def test_run(source, target, old_docs, new_docs): builder = CopyBuilder(source, target) builder.run() - builder.target.connect() + + target.connect() assert builder.target.query_one(criteria={"k": 0})["v"] == "new" assert builder.target.query_one(criteria={"k": 10})["v"] == "old" @@ -113,6 +114,8 @@ def test_query(source, target, old_docs, new_docs): source.update(old_docs) source.update(new_docs) builder.run() + + target.connect() all_docs = list(target.query(criteria={})) assert len(all_docs) == 14 assert min([d["k"] for d in all_docs]) == 6 @@ -128,6 +131,7 @@ def test_delete_orphans(source, target, old_docs, new_docs): source._collection.delete_many(deletion_criteria) builder.run() + target.connect() assert target._collection.count_documents(deletion_criteria) == 0 assert target.query_one(criteria={"k": 5})["v"] == "new" assert target.query_one(criteria={"k": 10})["v"] == "old" diff --git a/tests/builders/test_projection_builder.py b/tests/builders/test_projection_builder.py index 345947e4f..6f6636e8b 100644 --- a/tests/builders/test_projection_builder.py +++ b/tests/builders/test_projection_builder.py @@ -103,6 +103,8 @@ def test_update_targets(source1, source2, target): def test_run(source1, source2, target): builder = Projection_Builder(source_stores=[source1, source2], target_store=target) builder.run() + + target.connect() assert len(list(target.query())) == 15 assert target.query_one(criteria={"k": 0})["a"] == "a" assert
target.query_one(criteria={"k": 0})["d"] == "d" @@ -118,4 +120,6 @@ def test_query(source1, source2, target): query_by_key=[0, 1, 2, 3, 4], ) builder.run() + + target.connect() assert len(list(target.query())) == 5 diff --git a/tests/cli/test_init.py b/tests/cli/test_init.py index a53bd0408..fc393a252 100644 --- a/tests/cli/test_init.py +++ b/tests/cli/test_init.py @@ -31,6 +31,11 @@ def reporting_store(): store._collection.drop() +@pytest.fixture() +def memorystore(): + return MemoryStore("temp") + + def test_basic_run(): runner = CliRunner() result = runner.invoke(run, ["--help"]) @@ -41,8 +46,7 @@ def test_basic_run(): assert result.exit_code != 0 -def test_run_builder(mongostore): - memorystore = MemoryStore("temp") +def test_run_builder(mongostore, memorystore): builder = CopyBuilder(mongostore, memorystore) mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) @@ -57,7 +61,7 @@ def test_run_builder(mongostore): result = runner.invoke(run, ["-vvv", "--no_bars", "test_builder.json"]) assert result.exit_code == 0 - assert "Get" not in result.output + assert "Get " not in result.output assert "Update" not in result.output result = runner.invoke(run, ["-v", "-n", "2", "test_builder.json"]) @@ -67,12 +71,11 @@ def test_run_builder(mongostore): result = runner.invoke(run, ["-vvv", "-n", "2", "--no_bars", "test_builder.json"]) assert result.exit_code == 0 - assert "Get" not in result.output + assert "Get " not in result.output assert "Update" not in result.output -def test_run_builder_chain(mongostore): - memorystore = MemoryStore("temp") +def test_run_builder_chain(mongostore, memorystore): builder1 = CopyBuilder(mongostore, memorystore) builder2 = CopyBuilder(mongostore, memorystore) @@ -88,7 +91,7 @@ def test_run_builder_chain(mongostore): result = runner.invoke(run, ["-vvv", "--no_bars", "test_builders.json"]) assert result.exit_code == 0 - assert "Get" not in result.output + assert "Get " not in result.output assert "Update" not in result.output result = runner.invoke(run, ["-v", "-n", "2", "test_builders.json"]) @@ -98,12 +101,11 @@ def test_run_builder_chain(mongostore): result = runner.invoke(run, ["-vvv", "-n", "2", "--no_bars", "test_builders.json"]) assert result.exit_code == 0 - assert "Get" not in result.output + assert "Get " not in result.output assert "Update" not in result.output -def test_reporting(mongostore, reporting_store): - memorystore = MemoryStore("temp") +def test_reporting(mongostore, reporting_store, memorystore): builder = CopyBuilder(mongostore, memorystore) mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) @@ -155,8 +157,7 @@ def test_python_notebook_source(): assert "Ended multiprocessing: DummyBuilder" in result.output -def test_memray_run_builder(mongostore): - memorystore = MemoryStore("temp") +def test_memray_run_builder(mongostore, memorystore): builder = CopyBuilder(mongostore, memorystore) mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) @@ -171,7 +172,7 @@ def test_memray_run_builder(mongostore): result = runner.invoke(run, ["-vvv", "--no_bars", "--memray", "on", "test_builder.json"]) assert result.exit_code == 0 - assert "Get" not in result.output + assert "Get " not in result.output assert "Update" not in result.output result = runner.invoke(run, ["-v", "-n", "2", "--memray", "on", "test_builder.json"]) @@ -181,12 +182,11 @@ def test_memray_run_builder(mongostore): result = 
runner.invoke(run, ["-vvv", "-n", "2", "--no_bars", "--memray", "on", "test_builder.json"]) assert result.exit_code == 0 - assert "Get" not in result.output + assert "Get " not in result.output assert "Update" not in result.output -def test_memray_user_output_dir(mongostore): - memorystore = MemoryStore("temp") +def test_memray_user_output_dir(mongostore, memorystore): builder = CopyBuilder(mongostore, memorystore) mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) diff --git a/tests/stores/test_advanced_stores.py b/tests/stores/test_advanced_stores.py index 4695632d3..5d5c5a433 100644 --- a/tests/stores/test_advanced_stores.py +++ b/tests/stores/test_advanced_stores.py @@ -285,7 +285,8 @@ def sandbox_store(): memstore = MemoryStore() store = SandboxStore(memstore, sandbox="test") store.connect() - return store + yield store + store._collection.drop() def test_sandbox_count(sandbox_store): @@ -313,10 +314,10 @@ def test_sandbox_distinct(sandbox_store): assert sandbox_store.distinct("a") == [1] sandbox_store._collection.insert_one({"a": 4, "d": 5, "e": 6, "sbxn": ["test"]}) - assert sandbox_store.distinct("a")[1] == 4 + assert set(sandbox_store.distinct("a")) == {4, 1} sandbox_store._collection.insert_one({"a": 7, "d": 8, "e": 9, "sbxn": ["not_test"]}) - assert sandbox_store.distinct("a")[1] == 4 + assert set(sandbox_store.distinct("a")) == {4, 1} def test_sandbox_update(sandbox_store): diff --git a/tests/stores/test_aws.py b/tests/stores/test_aws.py index 1c7167048..c7a9b4533 100644 --- a/tests/stores/test_aws.py +++ b/tests/stores/test_aws.py @@ -7,6 +7,7 @@ from moto import mock_s3 from sshtunnel import BaseSSHTunnelForwarderError +from maggma.core.store import StoreError from maggma.stores import MemoryStore, MongoStore, S3Store from maggma.stores.ssh_tunnel import SSHTunnel @@ -229,7 +230,7 @@ def objects_in_bucket(key): def test_close(s3store): list(s3store.query()) s3store.close() - with pytest.raises(AttributeError): + with pytest.raises(StoreError): list(s3store.query()) diff --git a/tests/stores/test_mongolike.py b/tests/stores/test_mongolike.py index 997b7edde..bdd3e3629 100644 --- a/tests/stores/test_mongolike.py +++ b/tests/stores/test_mongolike.py @@ -4,7 +4,6 @@ from pathlib import Path from unittest import mock -import mongomock.collection import orjson import pymongo.collection import pytest @@ -29,10 +28,11 @@ def mongostore(): @pytest.fixture() -def montystore(tmp_dir): +def montystore(): store = MontyStore("maggma_test") store.connect() - return store + yield store + store._collection.drop() @pytest.fixture() @@ -240,8 +240,9 @@ def test_mongostore_newer_in(mongostore): def test_memory_store_connect(): memorystore = MemoryStore() assert memorystore._coll is None + assert "mem:" in memorystore.name memorystore.connect() - assert isinstance(memorystore._collection, mongomock.collection.Collection) + assert isinstance(memorystore._collection, pymongo.collection.Collection) def test_groupby(memorystore): @@ -281,7 +282,7 @@ def test_groupby(memorystore): # Monty store tests -def test_monty_store_connect(tmp_dir): +def test_monty_store_connect(): montystore = MontyStore(collection_name="my_collection") assert montystore._coll is None montystore.connect() @@ -551,14 +552,50 @@ def test_jsonstore_last_updated(test_dir): assert jsonstore.last_updated > start_time -def test_eq(mongostore, memorystore, jsonstore): +def test_eq(mongostore, memorystore, jsonstore, montystore): + assert montystore == montystore assert 
mongostore == mongostore assert memorystore == memorystore assert jsonstore == jsonstore assert mongostore != memorystore + assert mongostore != montystore assert mongostore != jsonstore assert memorystore != jsonstore + assert memorystore != montystore + + # test case courtesy of @sivonxay + + # two MemoryStores with the same collection name point to the same in-memory collection + store1 = MemoryStore(collection_name="test_collection") + store2 = MemoryStore(collection_name="test_collection") + store1.connect() + store2.connect() + assert store1 == store2 + store1.update([{"a": 1, "b": 2}, {"a": 2, "b": 3}], "a") + assert store2.count() == 2 + + # with different collection names, they do not + store1 = MemoryStore(collection_name="store1") + store2 = MemoryStore(collection_name="store2") + assert store1 != store2 + + store1.connect() + store2.connect() + + store1.update([{"a": 1, "b": 2}, {"a": 2, "b": 3}], "a") + assert store1.count() != store2.count() + + # likewise with the default collection name, which is unique per instance + store1 = MemoryStore(collection_name=None) + store2 = MemoryStore(collection_name=None) + assert store1 != store2 + + store1.connect() + store2.connect() + + store1.update([{"a": 1, "b": 2}, {"a": 2, "b": 3}], "a") + assert store1.count() != store2.count() @pytest.mark.skipif(
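A minimal usage sketch of the new MemoryStore semantics exercised by test_eq above. This is illustrative only, not part of the patch, and assumes pymongo-inmemory can provision mongod 6.0 on port 27019 as configured in setup.cfg.

```python
from maggma.stores import MemoryStore

# With the default collection_name=None, each instance gets a unique,
# datetime-based collection name, so two default-constructed stores are
# independent of one another.
s1, s2 = MemoryStore(), MemoryStore()
assert s1 != s2

# Stores constructed with the same collection name share one in-memory
# collection once connected (the behavior the new test_eq asserts).
a = MemoryStore(collection_name="shared")
b = MemoryStore(collection_name="shared")
a.connect()
b.connect()
a.update([{"task_id": 1, "x": 42}], key="task_id")
assert b.count() == 1  # data written through `a` is visible through `b`

a.close()
b.close()
```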