diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 595f75b0..87d09edb 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -60,7 +60,7 @@ BaseDescriptor, HasDescriptors, CUnicode, -) + SourceLink) from traitlets.utils import cast_unicode @@ -2019,6 +2019,17 @@ def another_update(self, change): l = link((mc, "i"), (mc, "j")) self.assertRaises(TraitError, setattr, mc, 'j', 2) + def test_source_link(self): + class A(HasTraits): + a = Int() + + x1 = A() + x2 = A(a = SourceLink(x1, "a", transform=(lambda x: x+1, lambda x: x-1))) + x1.a = 3 + self.assertEqual(x2.a, 4) + x2.a = 1 + self.assertEqual(x1.a, 0) + class TestDirectionalLink(TestCase): def test_connect_same(self): """Verify two traitlets of the same type can be linked together using directional_link.""" @@ -2110,6 +2121,17 @@ class A(HasTraits): a.value += 1 self.assertEqual(a.value, b.value) + def test_source_directional_link(self): + class A(HasTraits): + a = Int() + + x1 = A() + x2 = A(a = SourceLink(x1, "a", link_type=directional_link, transform=lambda x: x+1)) + x1.a = 3 + self.assertEqual(x2.a, 4) + x2.a = 1 + self.assertEqual(x1.a, 3) + class Pickleable(HasTraits): i = Int() diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index 20355904..f382bee4 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -39,21 +39,21 @@ # Adapted from enthought.traits, Copyright (c) Enthought, Inc., # also under the terms of the Modified BSD License. -from ast import literal_eval import contextlib +import enum import inspect import os import re import sys import types -import enum +from ast import literal_eval from warnings import warn, warn_explicit +from .utils.bunch import Bunch +from .utils.descriptions import describe, class_of, add_article, repr_type from .utils.getargspec import getargspec from .utils.importstring import import_item from .utils.sentinel import Sentinel -from .utils.bunch import Bunch -from .utils.descriptions import describe, class_of, add_article, repr_type SequenceTypes = (list, tuple, set, frozenset) @@ -81,6 +81,7 @@ "BaseDescriptor", "TraitType", "parse_notifier_name", + "SourceLink" ] # any TraitType subclass (that doesn't start with _) will be added automatically @@ -246,6 +247,7 @@ def _validate_link(*tuples): if not trait_name in obj.traits(): raise TypeError("%r has no trait %r" % (obj, trait_name)) + class link(object): """Link traits from different objects together so they remain in sync. @@ -366,6 +368,14 @@ def unlink(self): dlink = directional_link +class SourceLink: + def __init__(self, obj, attr, link_type=link, transform=None): + self.obj = obj + self.attr = attr + self.link_type = link_type + self.transform = transform + + #----------------------------------------------------------------------------- # Base Descriptor Class #----------------------------------------------------------------------------- @@ -600,6 +610,8 @@ def __set__(self, obj, value): """ if self.read_only: raise TraitError('The "%s" trait is read-only.' % self.name) + elif isinstance(value, SourceLink): + value.link_type((value.obj, value.attr), (obj, self.name), transform=value.transform) else: self.set(obj, value) @@ -731,6 +743,7 @@ def tag(self, **metadata): def default_value_repr(self): return repr(self.default_value) + #----------------------------------------------------------------------------- # The HasTraits implementation #-----------------------------------------------------------------------------