diff --git a/objectbox/box.py b/objectbox/box.py index c9da1d6..d013aca 100644 --- a/objectbox/box.py +++ b/objectbox/box.py @@ -150,6 +150,15 @@ def remove_all(self) -> int: obx_box_remove_all(self._c_box, ctypes.byref(count)) return int(count.value) - def query(self) -> QueryBuilder: - """ Creates a QueryBuilder for the Entity managed by the Box. """ - return QueryBuilder(self._ob, self) + def query(self, condition: Optional[QueryCondition] = None) -> QueryBuilder: + """ Creates a QueryBuilder for the Entity that is managed by the Box. + + :param condition: + If given, applies the given high-level condition to the new QueryBuilder object. + Useful for a user-friendly API design; for example: + ``box.query(name_property.equals("Johnny")).build()`` + """ + qb = QueryBuilder(self._ob, self) + if condition is not None: + condition.apply(qb) + return qb diff --git a/objectbox/condition.py b/objectbox/condition.py index 11d2485..504b7f8 100644 --- a/objectbox/condition.py +++ b/objectbox/condition.py @@ -1,95 +1,162 @@ from enum import Enum +from typing import * -class _ConditionOp(Enum): - eq = 1 - notEq = 2 - contains = 3 - startsWith = 4 - endsWith = 5 - gt = 6 - greaterOrEq = 7 - lt = 8 - lessOrEq = 9 - between = 10 + +class _QueryConditionOp(Enum): + EQ = 1 + NOT_EQ = 2 + CONTAINS = 3 + STARTS_WITH = 4 + ENDS_WITH = 5 + GT = 6 + GTE = 7 + LT = 8 + LTE = 9 + BETWEEN = 10 + NEAREST_NEIGHBOR = 11 class QueryCondition: - def __init__(self, property_id: int, op: _ConditionOp, value, value_b = None, case_sensitive: bool = True): + def __init__(self, property_id: int, op: _QueryConditionOp, args: Dict[str, Any]): + if op not in self._get_op_map(): + raise Exception(f"Invalid query condition op with ID: {op}") + self._property_id = property_id self._op = op - self._value = value - self._value_b = value_b - self._case_sensitive = case_sensitive - - def apply(self, builder: 'QueryBuilder'): - if self._op == _ConditionOp.eq: - if isinstance(self._value, str): - builder.equals_string(self._property_id, self._value, self._case_sensitive) - elif isinstance(self._value, int): - builder.equals_int(self._property_id, self._value) - else: - raise Exception("Unsupported type for 'eq': " + str(type(self._value))) - - elif self._op == _ConditionOp.notEq: - if isinstance(self._value, str): - builder.not_equals_string(self._property_id, self._value, self._case_sensitive) - elif isinstance(self._value, int): - builder.not_equals_int(self._property_id, self._value) - else: - raise Exception("Unsupported type for 'notEq': " + str(type(self._value))) - - elif self._op == _ConditionOp.contains: - if isinstance(self._value, str): - builder.contains_string(self._property_id, self._value, self._case_sensitive) - else: - raise Exception("Unsupported type for 'contains': " + str(type(self._value))) - - elif self._op == _ConditionOp.startsWith: - if isinstance(self._value, str): - builder.starts_with_string(self._property_id, self._value, self._case_sensitive) - else: - raise Exception("Unsupported type for 'startsWith': " + str(type(self._value))) - - elif self._op == _ConditionOp.endsWith: - if isinstance(self._value, str): - builder.ends_with_string(self._property_id, self._value, self._case_sensitive) - else: - raise Exception("Unsupported type for 'endsWith': " + str(type(self._value))) - - elif self._op == _ConditionOp.gt: - if isinstance(self._value, str): - builder.greater_than_string(self._property_id, self._value, self._case_sensitive) - elif isinstance(self._value, int): - builder.greater_than_int(self._property_id, self._value) - else: - raise Exception("Unsupported type for 'gt': " + str(type(self._value))) - - elif self._op == _ConditionOp.greaterOrEq: - if isinstance(self._value, str): - builder.greater_or_equal_string(self._property_id, self._value, self._case_sensitive) - elif isinstance(self._value, int): - builder.greater_or_equal_int(self._property_id, self._value) - else: - raise Exception("Unsupported type for 'greaterOrEq': " + str(type(self._value))) - - elif self._op == _ConditionOp.lt: - if isinstance(self._value, str): - builder.less_than_string(self._property_id, self._value, self._case_sensitive) - elif isinstance(self._value, int): - builder.less_than_int(self._property_id, self._value) - else: - raise Exception("Unsupported type for 'lt': " + str(type(self._value))) - - elif self._op == _ConditionOp.lessOrEq: - if isinstance(self._value, str): - builder.less_or_equal_string(self._property_id, self._value, self._case_sensitive) - elif isinstance(self._value, int): - builder.less_or_equal_int(self._property_id, self._value) - else: - raise Exception("Unsupported type for 'lessOrEq': " + str(type(self._value))) - - elif self._op == _ConditionOp.between: - if isinstance(self._value, int): - builder.between_2ints(self._property_id, self._value, self._value_b) - else: - raise Exception("Unsupported type for 'between': " + str(type(self._value))) \ No newline at end of file + self._args = args + + def _get_op_map(self): + return { + _QueryConditionOp.EQ: self._apply_eq, + _QueryConditionOp.NOT_EQ: self._apply_not_eq, + _QueryConditionOp.CONTAINS: self._apply_contains, + _QueryConditionOp.STARTS_WITH: self._apply_starts_with, + _QueryConditionOp.ENDS_WITH: self._apply_ends_with, + _QueryConditionOp.GT: self._apply_gt, + _QueryConditionOp.GTE: self._apply_gte, + _QueryConditionOp.LT: self._apply_lt, + _QueryConditionOp.LTE: self._apply_lte, + _QueryConditionOp.BETWEEN: self._apply_between, + _QueryConditionOp.NEAREST_NEIGHBOR: self._apply_nearest_neighbor + # ... new query condition here ... :) + } + + def _apply_eq(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.equals_string(self._property_id, value, case_sensitive) + elif isinstance(value, int): + qb.equals_int(self._property_id, value) + else: + raise Exception(f"Unsupported type for 'EQ': {type(value)}") + + def _apply_not_eq(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.not_equals_string(self._property_id, value, case_sensitive) + elif isinstance(value, int): + qb.not_equals_int(self._property_id, value) + else: + raise Exception(f"Unsupported type for 'NOT_EQ': {type(value)}") + + def _apply_contains(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.contains_string(self._property_id, value, case_sensitive) + else: + raise Exception(f"Unsupported type for 'CONTAINS': {type(self_value)}") + + def _apply_starts_with(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.starts_with_string(self._property_id, value, case_sensitive) + else: + raise Exception(f"Unsupported type for 'STARTS_WITH': {type(value)}") + + def _apply_ends_with(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.ends_with_string(self._property_id, value, case_sensitive) + else: + raise Exception(f"Unsupported type for 'ENDS_WITH': {type(value)}") + + def _apply_gt(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.greater_than_string(self._property_id, value, case_sensitive) + elif isinstance(value, int): + qb.greater_than_int(self._property_id, value) + else: + raise Exception(f"Unsupported type for 'GT': {type(value)}") + + def _apply_gt(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.greater_than_string(self._property_id, value, case_sensitive) + elif isinstance(value, int): + qb.greater_than_int(self._property_id, value) + else: + raise Exception(f"Unsupported type for 'GT': {type(value)}") + + def _apply_gte(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.greater_or_equal_string(self._property_id, value, case_sensitive) + elif isinstance(value, int): + qb.greater_or_equal_int(self._property_id, value) + else: + raise Exception(f"Unsupported type for 'GTE': {type(value)}") + + def _apply_lt(self, qb: 'QueryCondition'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.less_than_string(self._property_id, value, case_sensitive) + elif isinstance(value, int): + qb.less_than_int(self._property_id, value) + else: + raise Exception("Unsupported type for 'LT': " + str(type(value))) + + def _apply_lte(self, qb: 'QueryBuilder'): + value = self._args['value'] + case_sensitive = self._args['case_sensitive'] + if isinstance(value, str): + qb.less_or_equal_string(self._property_id, value, case_sensitive) + elif isinstance(value, int): + qb.less_or_equal_int(self._property_id, value) + else: + raise Exception(f"Unsupported type for 'LTE': {type(value)}") + + def _apply_between(self, qb: 'QueryBuilder'): + a = self._args['a'] + b = self._args['b'] + if isinstance(a, int): + qb.between_2ints(self._property_id, a, b) + else: + raise Exception(f"Unsupported type for 'BETWEEN': {type(a)}") + + def _apply_nearest_neighbor(self, qb: 'QueryCondition'): + query_vector = self._args['query_vector'] + element_count = self._args['element_count'] + + if len(query_vector) == 0: + raise Exception("query_vector can't be empty") + + is_float_vector = False + is_float_vector |= isinstance(query_vector, np.ndarray) and query_vector.dtype == np.float32 + is_float_vector |= isinstance(query_vector, list) and type(query_vector[0]) == float + if is_float_vector: + qb.nearest_neighbors_f32(self._property_id, query_vector, element_count) + else: + raise Exception(f"Unsupported type for 'NEAREST_NEIGHBOR': {type(query_vector)}") + + def apply(self, qb: 'QueryBuilder'): + self._get_op_map()[self._op](qb) diff --git a/objectbox/model/properties.py b/objectbox/model/properties.py index e55b407..3fc6a64 100644 --- a/objectbox/model/properties.py +++ b/objectbox/model/properties.py @@ -14,7 +14,7 @@ from enum import IntEnum -from objectbox.condition import QueryCondition, _ConditionOp +from objectbox.condition import QueryCondition, _QueryConditionOp from objectbox.c import * import flatbuffers.number_types import numpy as np @@ -160,39 +160,50 @@ def _set_flags(self): if isinstance(self._index, Index): # Generic index self._flags |= self._index.type - def op(self, op: _ConditionOp, value, case_sensitive: bool = True) -> QueryCondition: - return QueryCondition(self._id, op, value, case_sensitive) - def equals(self, value, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.eq, value, case_sensitive) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.EQ, args) + def not_equals(self, value, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.notEq, value, case_sensitive) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.NOT_EQ, args) + def contains(self, value: str, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.contains, value, case_sensitive) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.CONTAINS, args) + def starts_with(self, value: str, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.startsWith, value, case_sensitive) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.STARTS_WITH, args) + def ends_with(self, value: str, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.endsWith, value, case_sensitive) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.ENDS_WITH, args) + def greater_than(self, value, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.gt, value, case_sensitive) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.GT, args) + def greater_or_equal(self, value, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.greaterOrEq, value, case_sensitive) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.GTE, args) + def less_than(self, value, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.lt, value, case_sensitive) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.LT, args) + def less_or_equal(self, value, case_sensitive: bool = True) -> QueryCondition: - return self.op(_ConditionOp.lessOrEq, value, case_sensitive) - - def between(self, value_a, value_b) -> QueryCondition: - return QueryCondition(self._id, _ConditionOp.between, value_a, value_b) - + args = {'value': value, 'case_sensitive': case_sensitive} + return QueryCondition(self._id, _QueryConditionOp.LTE, args) + + def between(self, a, b) -> QueryCondition: + args = {'a': a, 'b': b} + return QueryCondition(self._id, _QueryConditionOp.BETWEEN, args) + + def nearest_neighbor(self, query_vector, element_count: int): + args = {'query_vector': query_vector, 'element_count': element_count} + return QueryCondition(self._id, _QueryConditionOp.NEAREST_NEIGHBOR, args) + # ID property (primary key) class Id(Property): diff --git a/tests/test_query.py b/tests/test_query.py index c59cc86..84dcd26 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -10,71 +10,48 @@ def test_basics(): ob = load_empty_test_objectbox() + box = objectbox.Box(ob, TestEntity) - object1 = TestEntity() - object1.str = "foo" - object1.int64 = 123 - object2 = TestEntity() - object2.str = "bar" - object2.int64 = 456 - id1 = box.put(object1) - box.put(object2) - - # String queries + box.put(TestEntity(str="foo", int64=123)) + box.put(TestEntity(str="bar", int64=456)) + + # String query str_prop: Property = TestEntity.get_property("str") - qb = box.query() - qb.equals_string(str_prop._id, "bar", True) - query = qb.build() + query = box.query(str_prop.equals("bar", case_sensitive=True)).build() assert query.count() == 1 assert query.find()[0].str == "bar" - qb = box.query() - qb.not_equals_string(str_prop._id, "bar", True) - query = qb.build() + query = box.query(str_prop.not_equals("bar", case_sensitive=True)).build() assert query.count() == 1 assert query.find()[0].str == "foo" - qb = box.query() - qb.contains_string(str_prop._id, "ba", True) - query = qb.build() + query = box.query(str_prop.contains("ba", case_sensitive=True)).build() assert query.count() == 1 assert query.find()[0].str == "bar" - qb = box.query() - qb.starts_with_string(str_prop._id, "f", True) - query = qb.build() + query = box.query(str_prop.starts_with("f", case_sensitive=True)).build() assert query.count() == 1 assert query.find()[0].str == "foo" - qb = box.query() - qb.ends_with_string(str_prop._id, "o", True) - query = qb.build() + query = box.query(str_prop.ends_with("o", case_sensitive=True)).build() assert query.count() == 1 assert query.find()[0].str == "foo" - qb = box.query() - qb.greater_than_string(str_prop._id, "bar", True) - query = qb.build() + query = box.query(str_prop.greater_than("bar", case_sensitive=True)).build() assert query.count() == 1 assert query.find()[0].str == "foo" - qb = box.query() - qb.greater_or_equal_string(str_prop._id, "bar", True) - query = qb.build() + query = box.query(str_prop.greater_or_equal("bar", case_sensitive=True)).build() assert query.count() == 2 assert query.find()[0].str == "foo" assert query.find()[1].str == "bar" - qb = box.query() - qb.less_than_string(str_prop._id, "foo", True) - query = qb.build() + query = box.query(str_prop.less_than("foo", case_sensitive=True)).build() assert query.count() == 1 assert query.find()[0].str == "bar" - qb = box.query() - qb.less_or_equal_string(str_prop._id, "foo", True) - query = qb.build() + query = box.query(str_prop.less_or_equal("foo", case_sensitive=True)).build() assert query.count() == 2 assert query.find()[0].str == "foo" assert query.find()[1].str == "bar" @@ -82,47 +59,33 @@ def test_basics(): # Int queries int_prop: Property = TestEntity.get_property("int64") - qb = box.query() - qb.equals_int(int_prop._id, 123) - query = qb.build() + query = box.query(int_prop.equals(123)).build() assert query.count() == 1 assert query.find()[0].int64 == 123 - qb = box.query() - qb.not_equals_int(int_prop._id, 123) - query = qb.build() + query = box.query(int_prop.not_equals(123)).build() assert query.count() == 1 assert query.find()[0].int64 == 456 - qb = box.query() - qb.greater_than_int(int_prop._id, 123) - query = qb.build() + query = box.query(int_prop.greater_than(123)).build() assert query.count() == 1 assert query.find()[0].int64 == 456 - qb = box.query() - qb.greater_or_equal_int(int_prop._id, 123) - query = qb.build() + query = box.query(int_prop.greater_or_equal(123)).build() assert query.count() == 2 assert query.find()[0].int64 == 123 assert query.find()[1].int64 == 456 - qb = box.query() - qb.less_than_int(int_prop._id, 456) - query = qb.build() + query = box.query(int_prop.less_than(456)).build() assert query.count() == 1 assert query.find()[0].int64 == 123 - qb = box.query() - qb.less_or_equal_int(int_prop._id, 456) - query = qb.build() + query = box.query(int_prop.less_or_equal(456)).build() assert query.count() == 2 assert query.find()[0].int64 == 123 assert query.find()[1].int64 == 456 - qb = box.query() - qb.between_2ints(int_prop._id, 100, 200) - query = qb.build() + query = box.query(int_prop.between(100, 200)).build() assert query.count() == 1 assert query.find()[0].int64 == 123 @@ -133,21 +96,17 @@ def test_basics(): def test_offset_limit(): ob = load_empty_test_objectbox() + box = objectbox.Box(ob, TestEntity) - object0 = TestEntity() - object1 = TestEntity() - object1.str = "a" - object2 = TestEntity() - object2.str = "b" - object3 = TestEntity() - object3.str = "c" - box.put([object0, object1, object2, object3]) + box.put(TestEntity()) + box.put(TestEntity(str="a")) + box.put(TestEntity(str="b")) + box.put(TestEntity(str="c")) + assert box.count() == 4 - int_prop: Property = TestEntity.get_property("int64") + int_prop = TestEntity.get_property("int64") - qb = box.query() - qb.equals_int(int_prop._id, 0) - query = qb.build() + query = box.query(int_prop.equals(0)).build() assert query.count() == 4 query.offset(2)