From 7e2009835f9b3ea255f4441bbaabaaa80060ec03 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 8 Oct 2021 14:30:56 +0100 Subject: [PATCH] Relaxed rules around subgraph operations; prevented in-place ops (#925) * Relaxed rules around subgraph operations; prevented in-place ops * Fixed to_subgraph methods --- py2neo/cypher/__init__.py | 4 +- py2neo/data.py | 16 +++++- test/unit/test_data.py | 104 +++++++++++++++++++++++++++++++++++--- 3 files changed, 114 insertions(+), 10 deletions(-) diff --git a/py2neo/cypher/__init__.py b/py2neo/cypher/__init__.py index cb8d3961..34c3bb5a 100644 --- a/py2neo/cypher/__init__.py +++ b/py2neo/cypher/__init__.py @@ -307,7 +307,7 @@ def to_subgraph(self): if s is None: s = s_ else: - s |= s_ + s = s | s_ return s def to_ndarray(self, dtype=None, order='K'): @@ -564,7 +564,7 @@ def to_subgraph(self): if s is None: s = value else: - s |= value + s = s | value return s diff --git a/py2neo/data.py b/py2neo/data.py index 3660d5aa..0ad0faef 100644 --- a/py2neo/data.py +++ b/py2neo/data.py @@ -109,8 +109,8 @@ def __init__(self, nodes=None, relationships=None): self.__nodes = frozenset(nodes or []) self.__relationships = frozenset(relationships or []) self.__nodes |= frozenset(chain.from_iterable(r.nodes for r in self.__relationships)) - if not self.__nodes: - raise ValueError("Subgraphs must contain at least one node") + #if not self.__nodes: + # raise ValueError("Subgraphs must contain at least one node") def __repr__(self): return "Subgraph({%s}, {%s})" % (", ".join(map(repr, self.nodes)), @@ -583,6 +583,18 @@ def __or__(self, other): # use the Walkable implementation. return Walkable.__or__(self, other) + def __ior__(self, other): + raise TypeError("In-place union is not permitted for %s objects" % self.__class__.__name__) + + def __iand__(self, other): + raise TypeError("In-place intersection is not permitted for %s objects" % self.__class__.__name__) + + def __isub__(self, other): + raise TypeError("In-place difference is not permitted for %s objects" % self.__class__.__name__) + + def __ixor__(self, other): + raise TypeError("In-place symmetric difference is not permitted for %s objects" % self.__class__.__name__) + @property def graph(self): return self._graph diff --git a/test/unit/test_data.py b/test/unit/test_data.py index 9b2aff2d..a39f2536 100644 --- a/test/unit/test_data.py +++ b/test/unit/test_data.py @@ -19,6 +19,8 @@ from io import StringIO from unittest import TestCase +from _pytest.python_api import raises + from py2neo.cypher import Record from py2neo.data import Subgraph, Walkable, Node, Relationship, Path, walk from py2neo.integration import Table @@ -405,8 +407,9 @@ def test_property_keys(self): assert self.subgraph.keys() == {"name", "age", "since"} def test_empty_subgraph(self): - with self.assertRaises(ValueError): - Subgraph() + s = Subgraph() + assert len(s.nodes) == 0 + assert len(s.relationships) == 0 class WalkableTestCase(TestCase): @@ -869,7 +872,17 @@ def test_can_concatenate_node_and_none(self): class UnionTestCase(TestCase): - def test_graph_union(self): + def test_node_union(self): + s = alice | bob + assert len(s.nodes) == 2 + assert len(s.relationships) == 0 + + def test_node_union_in_place(self): + n = Node() + with raises(TypeError): + n |= alice + + def test_subgraph_union(self): graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob) graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave) graph = graph_1 | graph_2 @@ -877,10 +890,33 @@ def test_graph_union(self): assert len(graph.relationships) == 5 assert graph.nodes == (alice | bob | carol | dave).nodes + def test_subgraph_union_in_place(self): + graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob) + graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave) + graph |= graph_2 + assert len(graph.nodes) == 4 + assert len(graph.relationships) == 5 + assert graph.nodes == (alice | bob | carol | dave).nodes + class IntersectionTestCase(TestCase): - def test_graph_intersection(self): + def test_node_intersection_same(self): + s = alice & alice + assert len(s.nodes) == 1 + assert len(s.relationships) == 0 + + def test_node_intersection_different(self): + s = alice & bob + assert len(s.nodes) == 0 + assert len(s.relationships) == 0 + + def test_node_intersection_in_place(self): + n = Node() + with raises(TypeError): + n &= alice + + def test_subgraph_intersection(self): graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob) graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave) graph = graph_1 & graph_2 @@ -888,10 +924,33 @@ def test_graph_intersection(self): assert len(graph.relationships) == 1 assert graph.nodes == (bob | carol).nodes + def test_subgraph_intersection_in_place(self): + graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob) + graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave) + graph &= graph_2 + assert len(graph.nodes) == 2 + assert len(graph.relationships) == 1 + assert graph.nodes == (bob | carol).nodes + class DifferenceTestCase(TestCase): - def test_graph_difference(self): + def test_node_difference_same(self): + s = alice - alice + assert len(s.nodes) == 0 + assert len(s.relationships) == 0 + + def test_node_difference_different(self): + s = alice - bob + assert len(s.nodes) == 1 + assert len(s.relationships) == 0 + + def test_node_difference_in_place(self): + n = Node() + with raises(TypeError): + n -= alice + + def test_subgraph_difference(self): graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob) graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave) graph = graph_1 - graph_2 @@ -899,10 +958,33 @@ def test_graph_difference(self): assert len(graph.relationships) == 2 assert graph.nodes == (alice | bob | carol).nodes + def test_subgraph_difference_in_place(self): + graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob) + graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave) + graph -= graph_2 + assert len(graph.nodes) == 3 + assert len(graph.relationships) == 2 + assert graph.nodes == (alice | bob | carol).nodes + class SymmetricDifferenceTestCase(TestCase): - def test_graph_symmetric_difference(self): + def test_node_symmetric_difference_same(self): + s = alice ^ alice + assert len(s.nodes) == 0 + assert len(s.relationships) == 0 + + def test_node_symmetric_difference_different(self): + s = alice ^ bob + assert len(s.nodes) == 2 + assert len(s.relationships) == 0 + + def test_node_symmetric_difference_in_place(self): + n = Node() + with raises(TypeError): + n ^= alice + + def test_subgraph_symmetric_difference(self): graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob) graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave) graph = graph_1 ^ graph_2 @@ -912,6 +994,16 @@ def test_graph_symmetric_difference(self): assert graph.relationships == frozenset(alice_knows_bob | alice_likes_carol | carol_married_to_dave | dave_works_for_dave) + def test_subgraph_symmetric_difference_in_place(self): + graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob) + graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave) + graph ^= graph_2 + assert len(graph.nodes) == 4 + assert len(graph.relationships) == 4 + assert graph.nodes == (alice | bob | carol | dave).nodes + assert graph.relationships == frozenset(alice_knows_bob | alice_likes_carol | + carol_married_to_dave | dave_works_for_dave) + def test_record_repr(): person = Record(["name", "age"], ["Alice", 33])