Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#2685] implement a thread-safe switch db context manager #2686

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mongoengine/base/metaclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __new__(mcs, name, bases, attrs):
) = mcs._import_classes()

if issubclass(new_class, Document):
new_class._collection = None
new_class._collections = {}

# Add class to the _document_registry
_document_registry[new_class._class_name] = new_class
Expand Down
40 changes: 40 additions & 0 deletions mongoengine/connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import warnings
from threading import local

from pymongo import MongoClient, ReadPreference, uri_parser
from pymongo.database import _check_name

from mongoengine.errors import DatabaseAliasError
from mongoengine.pymongo_support import PYMONGO_VERSION

__all__ = [
Expand All @@ -15,6 +17,9 @@
"get_connection",
"get_db",
"register_connection",
"set_local_db_alias",
"del_local_db_alias",
"get_local_db_alias"
]


Expand All @@ -26,6 +31,7 @@
_connection_settings = {}
_connections = {}
_dbs = {}
_local = local()

READ_PREFERENCE = ReadPreference.PRIMARY

Expand Down Expand Up @@ -372,7 +378,41 @@ def _clean_settings(settings_dict):
return _connections[db_alias]


def __local_db_alias():
if getattr(_local, "db_alias", None) is None:
_local.db_alias = {}
return _local.db_alias


def set_local_db_alias(local_alias, alias=DEFAULT_CONNECTION_NAME):
if not alias or not local_alias:
raise DatabaseAliasError(f"db alias and local_alias cannot be empty")

if alias not in __local_db_alias():
__local_db_alias()[alias] = []

__local_db_alias()[alias].append(local_alias)


def del_local_db_alias(alias):
if not alias:
raise DatabaseAliasError(f"db alias cannot be empty")

if alias not in __local_db_alias() or not __local_db_alias()[alias]:
raise DatabaseAliasError(f"local db alias not set: {alias}")

__local_db_alias()[alias].pop()


def get_local_db_alias(alias):
if alias in __local_db_alias() and __local_db_alias()[alias]:
alias = __local_db_alias()[alias][-1]
return alias


def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
alias = get_local_db_alias(alias)

if reconnect:
disconnect(alias)

Expand Down
45 changes: 40 additions & 5 deletions mongoengine/context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from pymongo.write_concern import WriteConcern

from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db, set_local_db_alias, del_local_db_alias
from mongoengine.pymongo_support import count_documents

__all__ = (
"switch_db_local",
"switch_db",
"switch_collection",
"no_dereference",
Expand All @@ -18,6 +19,36 @@
)


class switch_db_local:
"""switch_db_local alias context manager.

Switches a db alias in a thread-safe way.

Example ::
register_connection('testdb-1', 'mongoenginetest1')
register_connection('testdb-2', 'mongoenginetest2')

class Group(Document):
name = StringField()

# The following two calls to save() could be run concurrently
with switch_db_local('testdb-1'):
Group(name='test').save()
with switch_db_local('testdb-2'):
Group(name='test').save()
"""

def __init__(self, local_alias, alias=DEFAULT_CONNECTION_NAME):
self.local_alias = local_alias
self.alias = alias

def __enter__(self):
set_local_db_alias(self.local_alias, self.alias)

def __exit__(self, t, value, traceback):
del_local_db_alias(self.alias)


class switch_db:
"""switch_db alias context manager.

Expand Down Expand Up @@ -50,18 +81,22 @@ def __init__(self, cls, db_alias):
def __enter__(self):
"""Change the db_alias and clear the cached collection."""
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
self.cls._set_collection(None)
return self.cls

def __exit__(self, t, value, traceback):
"""Reset the db_alias and collection."""
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
self.cls._set_collection(self.collection)


class switch_collection:
"""switch_collection alias context manager.

Warning ::

### This is NOT completely thread-safe ###

Example ::

class Group(Document):
Expand Down Expand Up @@ -92,12 +127,12 @@ def _get_collection_name(cls):
return self.collection_name

self.cls._get_collection_name = _get_collection_name
self.cls._collection = None
self.cls._set_collection(None)
return self.cls

def __exit__(self, t, value, traceback):
"""Reset the collection."""
self.cls._collection = self.ori_collection
self.cls._set_collection(self.ori_collection)
self.cls._get_collection_name = self.ori_get_collection_name


Expand Down
29 changes: 19 additions & 10 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_document,
)
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db, get_local_db_alias
from mongoengine.context_managers import (
set_write_concern,
switch_collection,
Expand Down Expand Up @@ -196,15 +196,23 @@ def __hash__(self):

return hash(self.pk)

@classmethod
def _get_local_db_alias(cls):
return get_local_db_alias(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))

@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
return get_db(cls._get_local_db_alias())

@classmethod
def _disconnect(cls):
"""Detach the Document class from the (cached) database collection"""
cls._collection = None
"""Detach the Document class from all (cached) database collections"""
cls._collections = {}

@classmethod
def _set_collection(cls, collection):
cls._collections[cls._get_local_db_alias()] = collection

@classmethod
def _get_collection(cls):
Expand All @@ -216,14 +224,15 @@ def _get_collection(cls):
2. Creates indexes defined in this document's :attr:`meta` dictionary.
This happens only if `auto_create_index` is True.
"""
if not hasattr(cls, "_collection") or cls._collection is None:
local_db_alias = cls._get_local_db_alias()
if local_db_alias not in cls._collections:
# Get the collection, either capped or regular.
if cls._meta.get("max_size") or cls._meta.get("max_documents"):
cls._collection = cls._get_capped_collection()
cls._collections[local_db_alias] = cls._get_capped_collection()
else:
db = cls._get_db()
collection_name = cls._get_collection_name()
cls._collection = db[collection_name]
cls._collections[local_db_alias] = db[collection_name]

# Ensure indexes on the collection unless auto_create_index was
# set to False.
Expand All @@ -232,7 +241,7 @@ def _get_collection(cls):
if cls._meta.get("auto_create_index", True) and db.client.is_primary:
cls.ensure_indexes()

return cls._collection
return cls._collections[local_db_alias]

@classmethod
def _get_capped_collection(cls):
Expand Down Expand Up @@ -260,7 +269,7 @@ def _get_capped_collection(cls):
if options.get("max") != max_documents or options.get("size") != max_size:
raise InvalidCollectionError(
'Cannot create collection "{}" as a capped '
"collection as it already exists".format(cls._collection)
"collection as it already exists".format(collection_name)
)

return collection
Expand Down Expand Up @@ -837,7 +846,7 @@ def drop_collection(cls):
raise OperationError(
"Document %s has no collection defined (is it abstract ?)" % cls
)
cls._collection = None
cls._set_collection(None)
db = cls._get_db()
db.drop_collection(coll_name)

Expand Down
5 changes: 5 additions & 0 deletions mongoengine/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict

__all__ = (
"DatabaseAliasError",
"NotRegistered",
"InvalidDocumentError",
"LookUpError",
Expand All @@ -21,6 +22,10 @@ class MongoEngineException(Exception):
pass


class DatabaseAliasError(MongoEngineException):
pass


class NotRegistered(MongoEngineException):
pass

Expand Down
35 changes: 19 additions & 16 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ class BaseQuerySet:
__dereference = False
_auto_dereference = True

def __init__(self, document, collection):
def __init__(self, document, db_alias=None):
self._document = document
self._collection_obj = collection
self._db_alias = db_alias
self._mongo_query = None
self._query_obj = Q()
self._cls_query = {}
Expand All @@ -74,6 +74,8 @@ def __init__(self, document, collection):
self._as_pymongo = False
self._search_text = None

self.__init_using_collection()

# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
if document._meta.get("allow_inheritance") is True:
Expand All @@ -100,6 +102,12 @@ def __init__(self, document, collection):
# it anytime we change _limit. Inspired by how it is done in pymongo.Cursor
self._empty = False

def __init_using_collection(self):
self._using_collection = None
if self._db_alias is not None:
with switch_db(self._document, self._db_alias) as cls:
self._using_collection = cls._get_collection()

def __call__(self, q_obj=None, **query):
"""Filter the selected documents by calling the
:class:`~mongoengine.queryset.QuerySet` with a query.
Expand Down Expand Up @@ -137,9 +145,6 @@ def __getstate__(self):

obj_dict = self.__dict__.copy()

# don't picke collection, instead pickle collection params
obj_dict.pop("_collection_obj")

# don't pickle cursor
obj_dict["_cursor_obj"] = None

Expand All @@ -152,11 +157,11 @@ def __setstate__(self, obj_dict):
See https://github.com/MongoEngine/mongoengine/issues/442
"""

obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection()

# update attributes
self.__dict__.update(obj_dict)

self.__init_using_collection()

# forse load cursor
# self._cursor

Expand Down Expand Up @@ -494,7 +499,7 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None):
if rule == CASCADE:
cascade_refs = set() if cascade_refs is None else cascade_refs
# Handle recursive reference
if doc._collection == document_cls._collection:
if doc._collection == document_cls._get_collection():
for ref in queryset:
cascade_refs.add(ref.id)
refs = document_cls.objects(
Expand Down Expand Up @@ -777,14 +782,11 @@ def using(self, alias):
:param alias: The database alias
"""

with switch_db(self._document, alias) as cls:
collection = cls._get_collection()

return self._clone_into(self.__class__(self._document, collection))
return self._clone_into(self.__class__(self._document, alias))

def clone(self):
"""Create a copy of the current queryset."""
return self._clone_into(self.__class__(self._document, self._collection_obj))
return self._clone_into(self.__class__(self._document, self._db_alias))

def _clone_into(self, new_qs):
"""Copy all of the relevant properties of this queryset to
Expand Down Expand Up @@ -1531,7 +1533,7 @@ def sum(self, field):
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {"$unwind": "$" + field})

result = tuple(self._document._get_collection().aggregate(pipeline))
result = tuple(self._collection.aggregate(pipeline))

if result:
return result[0]["total"]
Expand All @@ -1558,7 +1560,7 @@ def average(self, field):
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {"$unwind": "$" + field})

result = tuple(self._document._get_collection().aggregate(pipeline))
result = tuple(self._collection.aggregate(pipeline))
if result:
return result[0]["total"]
return 0
Expand Down Expand Up @@ -1620,7 +1622,8 @@ def _collection(self):
"""Property that returns the collection object. This allows us to
perform operations only if the collection is accessed.
"""
return self._collection_obj
return self._document._get_collection() \
if self._using_collection is None else self._using_collection

@property
def _cursor_args(self):
Expand Down
2 changes: 1 addition & 1 deletion mongoengine/queryset/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __get__(self, instance, owner):

# owner is the document that contains the QuerySetManager
queryset_class = owner._meta.get("queryset_class", self.default)
queryset = queryset_class(owner, owner._get_collection())
queryset = queryset_class(owner)
if self.get_queryset:
arg_count = self.get_queryset.__code__.co_argcount
if arg_count == 1:
Expand Down