From 14872940bcd2b4101e8e81632b02f64d6e4279c3 Mon Sep 17 00:00:00 2001 From: David Gengenbach Date: Sun, 8 Nov 2020 13:53:42 +0100 Subject: [PATCH] Add support to manage multiple switches --- .dockerignore | 1 + .gitignore | 1 + README.md | 10 +- p4runtime_sh/p4runtime.py | 3 + p4runtime_sh/shell.py | 313 ++++++++++++++++++++++---------------- p4runtime_sh/test.py | 6 +- 6 files changed, 196 insertions(+), 138 deletions(-) diff --git a/.dockerignore b/.dockerignore index 9414382..16d9a6c 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,2 @@ Dockerfile +venv/ diff --git a/.gitignore b/.gitignore index 29a43b1..7dbf18f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *~ *.pyc .coverage +venv/ diff --git a/README.md b/README.md index eaaa2c0..bff2c1e 100644 --- a/README.md +++ b/README.md @@ -272,22 +272,24 @@ import p4runtime_sh.shell as sh # you can omit the config argument if the switch is already configured with the # correct P4 dataplane. -sh.setup( +client = sh.setup( device_id=1, grpc_addr='localhost:50051', election_id=(0, 1), # (high, low) - config=sh.FwdPipeConfig('config/p4info.pb.txt', 'config/device_config.bin') + config=sh.FwdPipeConfig('config/p4info.pb.txt', 'config/device_config.bin'), + # This enables controlling multiple switches + set_global_client=False ) # see p4runtime_sh/test.py for more examples -te = sh.TableEntry('')(action='') +te = sh.TableEntry('', client=client)(action='') te.match[''] = '' te.action[''] = '' te.insert() # ... -sh.teardown() +sh.teardown(client) ``` Note that at the moment the P4Runtime client object is a global variable, which diff --git a/p4runtime_sh/p4runtime.py b/p4runtime_sh/p4runtime.py index 8bd1339..a7359cd 100644 --- a/p4runtime_sh/p4runtime.py +++ b/p4runtime_sh/p4runtime.py @@ -26,6 +26,8 @@ from p4.v1 import p4runtime_pb2 from p4.v1 import p4runtime_pb2_grpc +from p4runtime_sh.context import Context + class P4RuntimeErrorFormatException(Exception): def __init__(self, message): @@ -137,6 +139,7 @@ def handle(*args, **kwargs): class P4RuntimeClient: def __init__(self, device_id, grpc_addr, election_id): + self.context = Context() self.device_id = device_id self.election_id = election_id logging.debug("Connecting to device {} at {}".format(device_id, grpc_addr)) diff --git a/p4runtime_sh/shell.py b/p4runtime_sh/shell.py index b10e5a0..833a5f7 100644 --- a/p4runtime_sh/shell.py +++ b/p4runtime_sh/shell.py @@ -26,14 +26,24 @@ from p4.v1 import p4runtime_pb2 from p4.config.v1 import p4info_pb2 from . import bytes_utils -from .context import P4RuntimeEntity, P4Type, Context +from .context import P4RuntimeEntity, P4Type from .utils import UserError, InvalidP4InfoError import google.protobuf.text_format from google.protobuf import descriptor +global_client = None # type: P4RuntimeClient -context = Context() -client = None + +def _get_client(client=None): + if client is None and global_client is None: + logging.error( + 'Called sh._get_client without either giving (1) a client argument or (2) a global client defined.\n' + 'Please call sh.setup(..., set_global_client=True) or provide a client=some_client argument.\n' + 'Ignoring this error will lead to unwanted consequences!' + ) + return None + + return global_client if client is None else client class UserUsageError(UserError): @@ -53,7 +63,8 @@ def __str__(self): class _PrintContext: - def __init__(self): + def __init__(self, client=None): + self._client = _get_client(client) self.skip_one = False self.stack = [] @@ -61,7 +72,7 @@ def find_table(self): for msg in reversed(self.stack): if msg.DESCRIPTOR.name == "TableEntry": try: - return context.get_name_from_id(msg.table_id) + return self._client.context.get_name_from_id(msg.table_id) except KeyError: return None return None @@ -70,36 +81,36 @@ def find_action(self): for msg in reversed(self.stack): if msg.DESCRIPTOR.name == "Action": try: - return context.get_name_from_id(msg.action_id) + return self._client.context.get_name_from_id(msg.action_id) except KeyError: return None return None -def _sub_object(field, value, pcontext): +def _sub_object(field, value, pcontext, client=None): id_ = value try: - return context.get_name_from_id(id_) + return _get_client(client).context.get_name_from_id(id_) except KeyError: logging.error("Unknown object id {}".format(id_)) -def _sub_mf(field, value, pcontext): +def _sub_mf(field, value, pcontext, client=None): id_ = value table_name = pcontext.find_table() if table_name is None: logging.error("Cannot find any table in context") return - return context.get_mf_name(table_name, id_) + return _get_client(client).context.get_mf_name(table_name, id_) -def _sub_ap(field, value, pcontext): +def _sub_ap(field, value, pcontext, client=None): id_ = value action_name = pcontext.find_action() if action_name is None: logging.error("Cannot find any action in context") return - return context.get_param_name(action_name, id_) + return _get_client(client).context.get_param_name(action_name, id_) def _gen_pretty_print_proto_field(substitutions, pcontext): @@ -124,11 +135,12 @@ def myPrintField(self, field, value): return myPrintField -def _repr_pretty_proto(msg, substitutions): +def _repr_pretty_proto(msg, substitutions, client=None): """A custom version of google.protobuf.text_format.MessageToString which represents Protobuf messages with a more user-friendly string. In particular, P4Runtime ids are supplemented with the P4 name and binary strings are displayed in hexadecimal format.""" - pcontext = _PrintContext() + client = _get_client(client) + pcontext = _PrintContext(client=client) def message_formatter(message, indent, as_one_line): # For each messages we do 2 passes: the first one updates the _PrintContext instance and @@ -157,41 +169,52 @@ def message_formatter(message, indent, as_one_line): return s +def get_sub_fn(fn, client=None): + client = _get_client(client) + return lambda field, value, pcontext: fn(field, value, pcontext, client=client) -def _repr_pretty_p4info(msg): + +def _repr_pretty_p4info(msg, client=None): + client = _get_client(client) + __sub_object = get_sub_fn(_sub_object, client=client) substitutions = { - "Table": {"const_default_action_id": _sub_object, - "implementation_id": _sub_object, - "direct_resource_ids": _sub_object}, - "ActionRef": {"id": _sub_object}, - "ActionProfile": {"table_ids": _sub_object}, - "DirectCounter": {"direct_table_id": _sub_object}, - "DirectMeter": {"direct_table_id": _sub_object}, + "Table": {"const_default_action_id": __sub_object, + "implementation_id": __sub_object, + "direct_resource_ids": __sub_object}, + "ActionRef": {"id": __sub_object}, + "ActionProfile": {"table_ids": __sub_object}, + "DirectCounter": {"direct_table_id": __sub_object}, + "DirectMeter": {"direct_table_id": __sub_object}, } - return _repr_pretty_proto(msg, substitutions) + return _repr_pretty_proto(msg, substitutions, client=client) -def _repr_pretty_p4runtime(msg): +def _repr_pretty_p4runtime(msg, client=None): + client = _get_client(client) + ___sub_object = get_sub_fn(_sub_object, client=client) + ___sub_mf = get_sub_fn(_sub_mf, client=client) + ___sub_ap = get_sub_fn(_sub_ap, client=client) substitutions = { - "TableEntry": {"table_id": _sub_object}, - "FieldMatch": {"field_id": _sub_mf}, - "Action": {"action_id": _sub_object}, - "Param": {"param_id": _sub_ap}, - "ActionProfileMember": {"action_profile_id": _sub_object}, - "ActionProfileGroup": {"action_profile_id": _sub_object}, - "MeterEntry": {"meter_id": _sub_object}, - "CounterEntry": {"counter_id": _sub_object}, - "ValueSetEntry": {"value_set_id": _sub_object}, - "RegisterEntry": {"register_id": _sub_object}, - "DigestEntry": {"digest_id": _sub_object}, - "DigestListAck": {"digest_id": _sub_object}, - "DigestList": {"digest_id": _sub_object}, + "TableEntry": {"table_id": ___sub_object}, + "FieldMatch": {"field_id": ___sub_mf}, + "Action": {"action_id": ___sub_object}, + "Param": {"param_id": ___sub_ap}, + "ActionProfileMember": {"action_profile_id": ___sub_object}, + "ActionProfileGroup": {"action_profile_id": ___sub_object}, + "MeterEntry": {"meter_id": ___sub_object}, + "CounterEntry": {"counter_id": ___sub_object}, + "ValueSetEntry": {"value_set_id": ___sub_object}, + "RegisterEntry": {"register_id": ___sub_object}, + "DigestEntry": {"digest_id": ___sub_object}, + "DigestListAck": {"digest_id": ___sub_object}, + "DigestList": {"digest_id": ___sub_object}, } - return _repr_pretty_proto(msg, substitutions) + return _repr_pretty_proto(msg, substitutions, client=client) class P4Object: - def __init__(self, obj_type, obj): + def __init__(self, obj_type, obj, client=None): + self._client = _get_client(client) self.name = obj.preamble.name self.id = obj.preamble.id self._obj_type = obj_type @@ -211,10 +234,10 @@ def __dir__(self): return d def _repr_pretty_(self, p, cycle): - p.text(_repr_pretty_p4info(self._obj)) + p.text(_repr_pretty_p4info(self._obj, client=self._client)) def __str__(self): - return _repr_pretty_p4info(self._obj) + return _repr_pretty_p4info(self._obj, client=self._client) def __getattr__(self, name): return getattr(self._obj, name) @@ -227,27 +250,28 @@ def msg(self): return self._obj def info(self): - print(_repr_pretty_p4info(self._obj)) + print(_repr_pretty_p4info(self._obj, client=self._client)) def actions(self): """Print list of actions, only for tables and action profiles.""" if self._obj_type == P4Type.table: for action in self._obj.action_refs: - print(context.get_name_from_id(action.id)) + print(self._client.context.get_name_from_id(action.id)) elif self._obj_type == P4Type.action_profile: t_id = self._obj.table_ids[0] - t_name = context.get_name_from_id(t_id) - t = context.get_table(t_name) + t_name = self._client.context.get_name_from_id(t_id) + t = self._client.context.get_table(t_name) for action in t.action_refs: - print(context.get_name_from_id(action.id)) + print(self._client.context.get_name_from_id(action.id)) else: raise UserError("'actions' is only available for tables and action profiles") class P4Objects: - def __init__(self, obj_type): + def __init__(self, obj_type, client=None): + self._client = _get_client(client) self._obj_type = obj_type - self._names = sorted([name for name, _ in context.get_objs(obj_type)]) + self._names = sorted([name for name, _ in self._client.context.get_objs(obj_type)]) self._iter = None self.__doc__ = """ All the {pnames} in the P4 program. @@ -265,11 +289,11 @@ def _ipython_key_completions_(self): return self._names def __getitem__(self, name): - obj = context.get_obj(self._obj_type, name) + obj = self._client.context.get_obj(self._obj_type, name) if obj is None: raise UserError("{} '{}' does not exist".format( self._obj_type.pretty_name, name)) - return P4Object(self._obj_type, obj) + return P4Object(self._obj_type, obj, client=self._client) def __setitem__(self, name, value): raise UserError("Operation not allowed") @@ -508,12 +532,13 @@ def _count(self): class Action: - def __init__(self, action_name=None): + def __init__(self, action_name=None, client=None): + self._client = _get_client(client) self._init = False if action_name is None: raise UserError("Please provide name for action") self.action_name = action_name - action_info = context.get_action(action_name) + action_info = self._client.context.get_action(action_name) if action_info is None: raise UserError("Unknown action '{}'".format(action_name)) self._action_id = action_info.preamble.id @@ -579,10 +604,10 @@ def msg(self): return msg def _from_msg(self, msg): - assert(self._action_id == msg.action_id) + assert (self._action_id == msg.action_id) self._params.clear() for p in msg.params: - p_name = context.get_param_name(self.action_name, p.param_id) + p_name = self._client.context.get_param_name(self.action_name, p.param_id) self._param_values[p_name] = p def __str__(self): @@ -597,7 +622,8 @@ def set(self, **kwargs): class _EntityBase: - def __init__(self, entity_type, p4runtime_cls, modify_only=False): + def __init__(self, entity_type, p4runtime_cls, modify_only=False, client=None): + self._client = _get_client(client) self._init = False self._entity_type = entity_type self._entry = p4runtime_cls() @@ -611,7 +637,7 @@ def __dir__(self): d.extend(["insert", "modify", "delete"]) return d - # to be called before issueing a P4Runtime request + # to be called before issuing a P4Runtime request # enforces checks that cannot be performed when setting individual fields def _validate_msg(self): return True @@ -621,11 +647,11 @@ def _update_msg(self): def __str__(self): self._update_msg() - return str(_repr_pretty_p4runtime(self._entry)) + return str(_repr_pretty_p4runtime(self._entry, client=self._client)) def _repr_pretty_(self, p, cycle): self._update_msg() - p.text(_repr_pretty_p4runtime(self._entry)) + p.text(_repr_pretty_p4runtime(self._entry, client=self._client)) def __getattr__(self, name): raise AttributeError("'{}' object has no attribute '{}'".format( @@ -641,7 +667,7 @@ def _write(self, type_): update = p4runtime_pb2.Update() update.type = type_ getattr(update.entity, self._entity_type.name).CopyFrom(self._entry) - client.write_update(update) + self._client.write_update(update) def insert(self): if self._modify_only: @@ -669,7 +695,9 @@ def read(self, function=None): entity = p4runtime_pb2.Entity() getattr(entity, self._entity_type.name).CopyFrom(self._entry) - iterator = client.read_one(entity) + iterator = self._client.read_one(entity) + + client = self._client # Cannot use a (simpler) generator here as we need to decorate __next__ with # @parse_p4runtime_error. @@ -694,9 +722,9 @@ def __next__(self): return next(self) if isinstance(self._entity, _P4EntityBase): - e = type(self._entity)(self._entity.name) # create new instance of same entity + e = type(self._entity)(self._entity.name, client=client) # create new instance of same entity else: - e = type(self._entity)() + e = type(self._entity)(client=client) msg = getattr(entity, self._entity._entity_type.name) e._from_msg(msg) # neither of these should be needed @@ -712,13 +740,13 @@ def __next__(self): class _P4EntityBase(_EntityBase): - def __init__(self, p4_type, entity_type, p4runtime_cls, name=None, modify_only=False): - super().__init__(entity_type, p4runtime_cls, modify_only) + def __init__(self, p4_type, entity_type, p4runtime_cls, name=None, modify_only=False, client=None): + super().__init__(entity_type, p4runtime_cls, modify_only, client=client) self._p4_type = p4_type if name is None: raise UserError("Please provide name for {}".format(p4_type.pretty_name)) self.name = name - self._info = P4Objects(p4_type)[name] + self._info = P4Objects(p4_type, client=self._client)[name] self.id = self._info.id def __dir__(self): @@ -730,10 +758,10 @@ def info(self): class ActionProfileMember(_P4EntityBase): - def __init__(self, action_profile_name=None): + def __init__(self, action_profile_name=None, client=None): super().__init__( P4Type.action_profile, P4RuntimeEntity.action_profile_member, - p4runtime_pb2.ActionProfileMember, action_profile_name) + p4runtime_pb2.ActionProfileMember, action_profile_name, client=client) self.member_id = 0 self.action = None self._valid_action_ids = self._get_action_set() @@ -766,14 +794,14 @@ def __dir__(self): def _get_action_set(self): t_id = self._info.table_ids[0] - t_name = context.get_name_from_id(t_id) - t = context.get_table(t_name) + t_name = self._client.context.get_name_from_id(t_id) + t = self._client.context.get_table(t_name) return set([action.id for action in t.action_refs]) def __call__(self, **kwargs): for name, value in kwargs.items(): if name == "action" and type(value) is str: - value = Action(value) + value = Action(value, client=self._client) setattr(self, name, value) return self @@ -807,8 +835,8 @@ def _from_msg(self, msg): self.member_id = msg.member_id if msg.HasField('action'): action = msg.action - action_name = context.get_name_from_id(action.action_id) - self.action = Action(action_name) + action_name = self._client.context.get_name_from_id(action.action_id) + self.action = Action(action_name, client=self._client) self.action._from_msg(action) def read(self, function=None): @@ -820,7 +848,7 @@ def read(self, function=None): server. Otherwise, function is applied to all the members returned by the server. """ - return super().read(function) + return super().read(function=function) class GroupMember: @@ -829,6 +857,7 @@ class GroupMember: Construct with GroupMember(, weight=, watch=). You can set / get attributes member_id (required), weight (default 1), watch (default 0). """ + def __init__(self, member_id=None, weight=1, watch=0): if member_id is None: raise UserError("member_id is required") @@ -878,10 +907,10 @@ def _repr_pretty_(self, p, cycle): class ActionProfileGroup(_P4EntityBase): - def __init__(self, action_profile_name=None): + def __init__(self, action_profile_name=None, client=None): super().__init__( P4Type.action_profile, P4RuntimeEntity.action_profile_group, - p4runtime_pb2.ActionProfileGroup, action_profile_name) + p4runtime_pb2.ActionProfileGroup, action_profile_name, client=client) self.group_id = 0 self.max_size = 0 self.members = [] @@ -970,18 +999,18 @@ def read(self, function=None): return super().read(function) -def _get_action_profile(table_name): - table = context.get_table(table_name) +def _get_action_profile(table_name, client=None): + table = _get_client(client).context.get_table(table_name) implementation_id = table.implementation_id if implementation_id == 0: return None try: - implementation_name = context.get_name_from_id(implementation_id) + implementation_name = _get_client(client).context.get_name_from_id(implementation_id) except KeyError: raise InvalidP4InfoError( "Invalid implementation_id {} for table '{}'".format( implementation_id, table_name)) - ap = context.get_obj(P4Type.action_profile, implementation_name) + ap = _get_client(client).context.get_obj(P4Type.action_profile, implementation_name) if ap is None: raise InvalidP4InfoError("Unknown implementation for table '{}'".format(table_name)) return ap @@ -993,6 +1022,7 @@ class OneshotAction: Construct with OneshotAction(, weight=, watch=). You can set / get attributes action (required), weight (default 1), watch (default 0). """ + def __init__(self, action=None, weight=1, watch=0): if action is None: raise UserError("action is required") @@ -1033,14 +1063,15 @@ def _repr_pretty_(self, p, cycle): class Oneshot: - def __init__(self, table_name=None): + def __init__(self, table_name=None, client=None): + self._client = _get_client(client) self._init = False if table_name is None: raise UserError("Please provide table name") self.table_name = table_name self.actions = [] - self._table_info = P4Objects(P4Type.table)[table_name] - ap = _get_action_profile(table_name) + self._table_info = P4Objects(P4Type.table, client=self._client)[table_name] + ap = _get_action_profile(table_name, client=self._client) if not ap: raise UserError("Cannot create Oneshot instance for a direct table") if not ap.with_selector: @@ -1093,8 +1124,8 @@ def msg(self): def _from_msg(self, msg): for action in msg.action_profile_actions: - action_name = context.get_name_from_id(action.action.action_id) - a = Action(action_name) + action_name = self._client.context.get_name_from_id(action.action.action_id) + a = Action(action_name, client=self._client) a._from_msg(action.action) self.actions.append(OneshotAction(a, action.weight, action.watch)) @@ -1251,16 +1282,16 @@ def _action_spec_name_to_type(cls, name): "oneshot": cls._ActionSpecType.ONESHOT, }.get(name, None) - def __init__(self, table_name=None): + def __init__(self, table_name=None, client=None): super().__init__( P4Type.table, P4RuntimeEntity.table_entry, - p4runtime_pb2.TableEntry, table_name) + p4runtime_pb2.TableEntry, table_name, client=client) self.match = MatchKey(table_name, self._info.match_fields) self._action_spec_type = self._ActionSpecType.NONE self._action_spec = None self.priority = 0 self.is_default = False - ap = _get_action_profile(table_name) + ap = _get_action_profile(table_name, client=self._client) if ap is None: self._support_members = False self._support_groups = False @@ -1272,9 +1303,9 @@ def __init__(self, table_name=None): for res_id in self._info.direct_resource_ids: prefix = (res_id & 0xff000000) >> 24 if prefix == p4info_pb2.P4Ids.DIRECT_COUNTER: - self._direct_counter = context.get_obj_by_id(res_id) + self._direct_counter = self._client.context.get_obj_by_id(res_id) elif prefix == p4info_pb2.P4Ids.DIRECT_METER: - self._direct_meter = context.get_obj_by_id(res_id) + self._direct_meter = self._client.context.get_obj_by_id(res_id) self._counter_data = None self._meter_config = None self.__doc__ = """ @@ -1372,7 +1403,7 @@ def __dir__(self): def __call__(self, **kwargs): for name, value in kwargs.items(): if name == "action" and type(value) is str: - value = Action(value) + value = Action(value, client=self._client) setattr(self, name, value) return self @@ -1523,19 +1554,19 @@ def _from_msg(self, msg): self.priority = msg.priority self.is_default = msg.is_default_action for mf in msg.match: - mf_name = context.get_mf_name(self.name, mf.field_id) + mf_name = self._client.context.get_mf_name(self.name, mf.field_id) self.match._mk[mf_name] = mf if msg.action.HasField('action'): action = msg.action.action - action_name = context.get_name_from_id(action.action_id) - self.action = Action(action_name) + action_name = self._client.context.get_name_from_id(action.action_id) + self.action = Action(action_name, client=self._client) self.action._from_msg(action) elif msg.action.HasField('action_profile_member_id'): self.member_id = msg.action.action_profile_member_id elif msg.action.HasField('action_profile_group_id'): self.group_id = msg.action.action_profile_group_id elif msg.action.HasField('action_profile_action_set'): - self.oneshot = Oneshot(self.name) + self.oneshot = Oneshot(self.name, client=self._client) self.oneshot._from_msg(msg.action.action_profile_action_set) if msg.HasField('counter_data'): self._counter_data = _CounterData( @@ -1617,8 +1648,8 @@ def clear_meter_config(self): class _CounterEntryBase(_P4EntityBase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *args, client=None, **kwargs): + super().__init__(*args, **kwargs, client=client) self._counter_type = self._info.spec.unit self._data = None @@ -1679,11 +1710,13 @@ def clear_data(self): class CounterEntry(_CounterEntryBase): - def __init__(self, counter_name=None): + def __init__(self, counter_name=None, client=None): super().__init__( P4Type.counter, P4RuntimeEntity.counter_entry, p4runtime_pb2.CounterEntry, counter_name, - modify_only=True) + modify_only=True, + client=client + ) self._entry.counter_id = self.id self.__doc__ = """ An entry for counter '{}' @@ -1737,18 +1770,18 @@ def read(self, function=None): class DirectCounterEntry(_CounterEntryBase): - def __init__(self, direct_counter_name=None): + def __init__(self, direct_counter_name=None, client=None): super().__init__( P4Type.direct_counter, P4RuntimeEntity.direct_counter_entry, p4runtime_pb2.DirectCounterEntry, direct_counter_name, - modify_only=True) + modify_only=True, client=client) self._direct_table_id = self._info.direct_table_id try: - self._direct_table_name = context.get_name_from_id(self._direct_table_id) + self._direct_table_name = self._client.context.get_name_from_id(self._direct_table_id) except KeyError: raise InvalidP4InfoError("direct_table_id {} is not a valid table id".format( self._direct_table_id)) - self._table_entry = TableEntry(self._direct_table_name) + self._table_entry = TableEntry(self._direct_table_name, client=self._client) self.__doc__ = """ An entry for direct counter '{}' @@ -1774,7 +1807,7 @@ def __setattr__(self, name, value): raise UserError("Direct counters are not index-based") if name == "table_entry": if value is None: - self._table_entry = TableEntry(self._direct_table_name) + self._table_entry = TableEntry(self._direct_table_name, client=self._client) return if not isinstance(value, TableEntry): raise UserError("table_entry must be an instance of TableEntry") @@ -1824,8 +1857,8 @@ def read(self, function=None): class _MeterEntryBase(_P4EntityBase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *args, client=None, **kwargs): + super().__init__(*args, client=client, **kwargs) self._meter_type = self._info.spec.unit self._config = None @@ -1885,11 +1918,13 @@ def clear_config(self): class MeterEntry(_MeterEntryBase): - def __init__(self, meter_name=None): + def __init__(self, meter_name=None, client=None): super().__init__( P4Type.meter, P4RuntimeEntity.meter_entry, p4runtime_pb2.MeterEntry, meter_name, - modify_only=True) + modify_only=True, + client=client + ) self._entry.meter_id = self.id self.__doc__ = """ An entry for meter '{}' @@ -1947,18 +1982,20 @@ def read(self, function=None): class DirectMeterEntry(_MeterEntryBase): - def __init__(self, direct_meter_name=None): + def __init__(self, direct_meter_name=None, client=None): super().__init__( P4Type.direct_meter, P4RuntimeEntity.direct_meter_entry, p4runtime_pb2.DirectMeterEntry, direct_meter_name, - modify_only=True) + modify_only=True, + client=client + ) self._direct_table_id = self._info.direct_table_id try: - self._direct_table_name = context.get_name_from_id(self._direct_table_id) + self._direct_table_name = self._client.context.get_name_from_id(self._direct_table_id) except KeyError: raise InvalidP4InfoError("direct_table_id {} is not a valid table id".format( self._direct_table_id)) - self._table_entry = TableEntry(self._direct_table_name) + self._table_entry = TableEntry(self._direct_table_name, client=self._client) self.__doc__ = """ An entry for direct meter '{}' @@ -2038,9 +2075,10 @@ def read(self, function=None): class P4RuntimeEntityBuilder: - def __init__(self, obj_type, entity_type, entity_cls): + def __init__(self, obj_type, entity_type, entity_cls, client=None): + self._client = _get_client(client) self._obj_type = obj_type - self._names = sorted([name for name, _ in context.get_objs(obj_type)]) + self._names = sorted([name for name, _ in self._client.context.get_objs(obj_type)]) self._entity_type = entity_type self._entity_cls = entity_cls self.__doc__ = """Construct a {} entity @@ -2055,7 +2093,7 @@ def _ipython_key_completions_(self): return self._names def __getitem__(self, name): - obj = context.get_obj(self._obj_type, name) + obj = self._client.context.get_obj(self._obj_type, name) if obj is None: raise UserError("{} '{}' does not exist".format( self._obj_type.pretty_name, name)) @@ -2077,6 +2115,7 @@ class Replica: Construct with Replica(egress_port, instance=). You can set / get attributes egress_port (required), instance (default 0). """ + def __init__(self, egress_port=None, instance=0): if egress_port is None: raise UserError("egress_port is required") @@ -2118,10 +2157,12 @@ def _repr_pretty_(self, p, cycle): class MulticastGroupEntry(_EntityBase): - def __init__(self, group_id=0): + def __init__(self, group_id=0, client=None): super().__init__( P4RuntimeEntity.packet_replication_engine_entry, - p4runtime_pb2.PacketReplicationEngineEntry) + p4runtime_pb2.PacketReplicationEngineEntry, + client=client + ) self.group_id = group_id self.replicas = [] self.__doc__ = """ @@ -2186,10 +2227,12 @@ def add(self, egress_port=None, instance=0): class CloneSessionEntry(_EntityBase): - def __init__(self, session_id=0): + def __init__(self, session_id=0, client=None): super().__init__( P4RuntimeEntity.packet_replication_engine_entry, - p4runtime_pb2.PacketReplicationEngineEntry) + p4runtime_pb2.PacketReplicationEngineEntry, + client=client + ) self.session_id = session_id self.replicas = [] self.cos = 0 @@ -2263,7 +2306,7 @@ def add(self, egress_port=None, instance=0): return self -def Write(input_): +def Write(input_, client=None): """ Reads a WriteRequest from a file (text format) and sends it to the server. It rewrites the device id and election id appropriately. @@ -2272,19 +2315,19 @@ def Write(input_): if os.path.isfile(input_): with open(input_, 'r') as f: google.protobuf.text_format.Merge(f.read(), req) - client.write(req) + _get_client(client).write(req) else: raise UserError( "Write only works with files at the moment and '{}' is not a file".format( input_)) -def APIVersion(): +def APIVersion(client=None): """ Returns the version of the P4Runtime API implemented by the server, using the Capabilities RPC. """ - return client.api_version() + return _get_client(client).api_version() # see https://ipython.readthedocs.io/en/stable/config/details.html @@ -2340,8 +2383,7 @@ def pipe_config(arg): return parser -def setup(device_id=1, grpc_addr='localhost:50051', election_id=(1, 0), config=None): - global client +def setup(device_id=1, grpc_addr='localhost:50051', election_id=(1, 0), config=None, set_global_client=True): logging.debug("Creating P4Runtime client") client = P4RuntimeClient(device_id, grpc_addr, election_id) @@ -2377,14 +2419,21 @@ def setup(device_id=1, grpc_addr='localhost:50051', election_id=(1, 0), config=N sys.exit(1) logging.debug("Parsing P4Info message") - context.set_p4info(p4info) + client.context.set_p4info(p4info) + + if set_global_client: + global global_client + global_client = client + return client + +def teardown(client=None): + global global_client -def teardown(): - global client logging.debug("Tearing down P4Runtime client") - client.tear_down() - client = None + _get_client(client).tear_down() + if _get_client(client) == global_client: + global_client = None def main(): @@ -2393,7 +2442,7 @@ def main(): if args.verbose: logging.basicConfig(level=logging.DEBUG) - setup(args.device_id, args.grpc_addr, args.election_id, args.config) + setup(args.device_id, args.grpc_addr, args.election_id, args.config, set_global_client=True) c = Config() c.TerminalInteractiveShell.banner1 = '*** Welcome to the IPython shell for P4Runtime ***' @@ -2414,7 +2463,7 @@ def main(): "ActionProfileGroup": ActionProfileGroup, "OneshotAction": OneshotAction, "Oneshot": Oneshot, - "p4info": context.p4info, + "p4info": _get_client().context.p4info, "Write": Write, "Replica": Replica, "MulticastGroupEntry": MulticastGroupEntry, @@ -2442,7 +2491,7 @@ def main(): start_ipython(user_ns=user_ns, config=c, argv=[]) - client.tear_down() + _get_client().tear_down() if __name__ == '__main__': # pragma: no cover diff --git a/p4runtime_sh/test.py b/p4runtime_sh/test.py index 17b3920..bf2802c 100644 --- a/p4runtime_sh/test.py +++ b/p4runtime_sh/test.py @@ -91,7 +91,9 @@ def setUp(self): super().setUp() self.serve() - def run_sh(self, args=[]): + def run_sh(self, args=None): + if args is None: + args = [] new_args = ["p4runtime-sh", "--grpc-addr", self.grpc_addr] + args rc = 0 stdout = None @@ -290,7 +292,7 @@ def test_table_entry_lpm(self, input_, value, length): te.insert() # Cannot use format here because it would require escaping all braces, - # which would make wiriting tests much more annoying + # which would make writing tests much more annoying expected_entry = """ table_id: 33567650 match {