Skip to content

Commit

Permalink
Relaxed rules around subgraph operations; prevented in-place ops (#925)
Browse files Browse the repository at this point in the history
* Relaxed rules around subgraph operations; prevented in-place ops

* Fixed to_subgraph methods
  • Loading branch information
technige authored Oct 8, 2021
1 parent 166a131 commit 7e20098
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 10 deletions.
4 changes: 2 additions & 2 deletions py2neo/cypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def to_subgraph(self):
if s is None:
s = s_
else:
s |= s_
s = s | s_
return s

def to_ndarray(self, dtype=None, order='K'):
Expand Down Expand Up @@ -564,7 +564,7 @@ def to_subgraph(self):
if s is None:
s = value
else:
s |= value
s = s | value
return s


Expand Down
16 changes: 14 additions & 2 deletions py2neo/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def __init__(self, nodes=None, relationships=None):
self.__nodes = frozenset(nodes or [])
self.__relationships = frozenset(relationships or [])
self.__nodes |= frozenset(chain.from_iterable(r.nodes for r in self.__relationships))
if not self.__nodes:
raise ValueError("Subgraphs must contain at least one node")
#if not self.__nodes:
# raise ValueError("Subgraphs must contain at least one node")

def __repr__(self):
return "Subgraph({%s}, {%s})" % (", ".join(map(repr, self.nodes)),
Expand Down Expand Up @@ -583,6 +583,18 @@ def __or__(self, other):
# use the Walkable implementation.
return Walkable.__or__(self, other)

def __ior__(self, other):
raise TypeError("In-place union is not permitted for %s objects" % self.__class__.__name__)

def __iand__(self, other):
raise TypeError("In-place intersection is not permitted for %s objects" % self.__class__.__name__)

def __isub__(self, other):
raise TypeError("In-place difference is not permitted for %s objects" % self.__class__.__name__)

def __ixor__(self, other):
raise TypeError("In-place symmetric difference is not permitted for %s objects" % self.__class__.__name__)

@property
def graph(self):
return self._graph
Expand Down
104 changes: 98 additions & 6 deletions test/unit/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from io import StringIO
from unittest import TestCase

from _pytest.python_api import raises

from py2neo.cypher import Record
from py2neo.data import Subgraph, Walkable, Node, Relationship, Path, walk
from py2neo.integration import Table
Expand Down Expand Up @@ -405,8 +407,9 @@ def test_property_keys(self):
assert self.subgraph.keys() == {"name", "age", "since"}

def test_empty_subgraph(self):
with self.assertRaises(ValueError):
Subgraph()
s = Subgraph()
assert len(s.nodes) == 0
assert len(s.relationships) == 0


class WalkableTestCase(TestCase):
Expand Down Expand Up @@ -869,40 +872,119 @@ def test_can_concatenate_node_and_none(self):

class UnionTestCase(TestCase):

def test_graph_union(self):
def test_node_union(self):
s = alice | bob
assert len(s.nodes) == 2
assert len(s.relationships) == 0

def test_node_union_in_place(self):
n = Node()
with raises(TypeError):
n |= alice

def test_subgraph_union(self):
graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
graph = graph_1 | graph_2
assert len(graph.nodes) == 4
assert len(graph.relationships) == 5
assert graph.nodes == (alice | bob | carol | dave).nodes

def test_subgraph_union_in_place(self):
graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
graph |= graph_2
assert len(graph.nodes) == 4
assert len(graph.relationships) == 5
assert graph.nodes == (alice | bob | carol | dave).nodes


class IntersectionTestCase(TestCase):

def test_graph_intersection(self):
def test_node_intersection_same(self):
s = alice & alice
assert len(s.nodes) == 1
assert len(s.relationships) == 0

def test_node_intersection_different(self):
s = alice & bob
assert len(s.nodes) == 0
assert len(s.relationships) == 0

def test_node_intersection_in_place(self):
n = Node()
with raises(TypeError):
n &= alice

def test_subgraph_intersection(self):
graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
graph = graph_1 & graph_2
assert len(graph.nodes) == 2
assert len(graph.relationships) == 1
assert graph.nodes == (bob | carol).nodes

def test_subgraph_intersection_in_place(self):
graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
graph &= graph_2
assert len(graph.nodes) == 2
assert len(graph.relationships) == 1
assert graph.nodes == (bob | carol).nodes


class DifferenceTestCase(TestCase):

def test_graph_difference(self):
def test_node_difference_same(self):
s = alice - alice
assert len(s.nodes) == 0
assert len(s.relationships) == 0

def test_node_difference_different(self):
s = alice - bob
assert len(s.nodes) == 1
assert len(s.relationships) == 0

def test_node_difference_in_place(self):
n = Node()
with raises(TypeError):
n -= alice

def test_subgraph_difference(self):
graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
graph = graph_1 - graph_2
assert len(graph.nodes) == 3
assert len(graph.relationships) == 2
assert graph.nodes == (alice | bob | carol).nodes

def test_subgraph_difference_in_place(self):
graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
graph -= graph_2
assert len(graph.nodes) == 3
assert len(graph.relationships) == 2
assert graph.nodes == (alice | bob | carol).nodes


class SymmetricDifferenceTestCase(TestCase):

def test_graph_symmetric_difference(self):
def test_node_symmetric_difference_same(self):
s = alice ^ alice
assert len(s.nodes) == 0
assert len(s.relationships) == 0

def test_node_symmetric_difference_different(self):
s = alice ^ bob
assert len(s.nodes) == 2
assert len(s.relationships) == 0

def test_node_symmetric_difference_in_place(self):
n = Node()
with raises(TypeError):
n ^= alice

def test_subgraph_symmetric_difference(self):
graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
graph = graph_1 ^ graph_2
Expand All @@ -912,6 +994,16 @@ def test_graph_symmetric_difference(self):
assert graph.relationships == frozenset(alice_knows_bob | alice_likes_carol |
carol_married_to_dave | dave_works_for_dave)

def test_subgraph_symmetric_difference_in_place(self):
graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
graph ^= graph_2
assert len(graph.nodes) == 4
assert len(graph.relationships) == 4
assert graph.nodes == (alice | bob | carol | dave).nodes
assert graph.relationships == frozenset(alice_knows_bob | alice_likes_carol |
carol_married_to_dave | dave_works_for_dave)


def test_record_repr():
person = Record(["name", "age"], ["Alice", 33])
Expand Down

0 comments on commit 7e20098

Please sign in to comment.