Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: drop mediator_terms and recipient_terms #2515

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ class Meta:

def __init__(
self,
id: str = None,
state: str = None,
id: Optional[str] = None,
state: Optional[str] = None,
*,
created_at: Union[str, datetime] = None,
updated_at: Union[str, datetime] = None,
created_at: Union[str, datetime, None] = None,
updated_at: Union[str, datetime, None] = None,
new_with_id: bool = False,
):
"""Initialize a new BaseRecord."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ async def test_create_request_multitenant(self):

async def test_create_request_mediation_id(self):
mediation_record = MediationRecord(
mediation_id="test_medation_id",
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
Expand Down Expand Up @@ -866,6 +867,7 @@ async def test_create_response_multitenant(self):
)

mediation_record = MediationRecord(
mediation_id="test_medation_id",
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
Expand Down Expand Up @@ -936,6 +938,7 @@ async def test_create_response_bad_state(self):

async def test_create_response_mediation(self):
mediation_record = MediationRecord(
mediation_id="test_medation_id",
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from ..mediation_deny_handler import MediationDenyHandler

TEST_CONN_ID = "conn-id"
TEST_MEDIATOR_TERMS = ["test", "mediator", "terms"]
TEST_RECIPIENT_TERMS = ["test", "recipient", "terms"]


class TestMediationDenyHandler(AsyncTestCase):
Expand All @@ -22,9 +20,7 @@ async def setUp(self):
"""Setup test dependencies."""
self.context = RequestContext.test_context()
self.session = await self.context.session()
self.context.message = MediationDeny(
mediator_terms=TEST_MEDIATOR_TERMS, recipient_terms=TEST_RECIPIENT_TERMS
)
self.context.message = MediationDeny()
self.context.connection_ready = True
self.context.connection_record = ConnRecord(connection_id=TEST_CONN_ID)

Expand All @@ -50,5 +46,3 @@ async def test_handler(self):
)
assert record
assert record.state == MediationRecord.STATE_DENIED
assert record.mediator_terms == TEST_MEDIATOR_TERMS
assert record.recipient_terms == TEST_RECIPIENT_TERMS
28 changes: 2 additions & 26 deletions aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,8 @@ async def receive_request(
"MediationRecord already exists for connection"
)

# TODO: Determine if terms are acceptable
record = MediationRecord(
connection_id=connection_id,
mediator_terms=request.mediator_terms,
recipient_terms=request.recipient_terms,
)
await record.save(session, reason="New mediation request received")
return record
Expand Down Expand Up @@ -186,19 +183,11 @@ async def grant_request(
async def deny_request(
self,
mediation_id: str,
*,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
) -> Tuple[MediationRecord, MediationDeny]:
"""Deny a mediation request and prepare a deny message.

Args:
mediation_id: mediation record ID to deny
mediator_terms (Sequence[str]): updated mediator terms to return to
requester.
recipient_terms (Sequence[str]): updated recipient terms to return to
requester.

Returns:
MediationDeny: message to return to denied client.

Expand All @@ -215,9 +204,7 @@ async def deny_request(
mediation_record.state = MediationRecord.STATE_DENIED
await mediation_record.save(session, reason="Mediation request denied")

deny = MediationDeny(
mediator_terms=mediator_terms, recipient_terms=recipient_terms
)
deny = MediationDeny()
return mediation_record, deny

async def _handle_keylist_update_add(
Expand Down Expand Up @@ -442,15 +429,11 @@ async def clear_default_mediator(self):
async def prepare_request(
self,
connection_id: str,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
) -> Tuple[MediationRecord, MediationRequest]:
"""Prepare a MediationRequest Message, saving a new mediation record.

Args:
connection_id (str): ID representing mediator
mediator_terms (Sequence[str]): mediator_terms
recipient_terms (Sequence[str]): recipient_terms

Returns:
MediationRequest: message to send to mediator
Expand All @@ -459,15 +442,11 @@ async def prepare_request(
record = MediationRecord(
role=MediationRecord.ROLE_CLIENT,
connection_id=connection_id,
mediator_terms=mediator_terms,
recipient_terms=recipient_terms,
)

async with self._profile.session() as session:
await record.save(session, reason="Creating new mediation request.")
request = MediationRequest(
mediator_terms=mediator_terms, recipient_terms=recipient_terms
)
request = MediationRequest()
return record, request

async def request_granted(self, record: MediationRecord, grant: MediationGrant):
Expand Down Expand Up @@ -495,9 +474,6 @@ async def request_denied(self, record: MediationRecord, deny: MediationDeny):

"""
record.state = MediationRecord.STATE_DENIED
# TODO Record terms elsewhere?
record.mediator_terms = deny.mediator_terms
record.recipient_terms = deny.recipient_terms
async with self._profile.session() as session:
await record.save(session, reason="Mediation request denied.")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""mediate-deny message used to notify mediation client of a denied mediation request."""

from typing import Sequence

from marshmallow import fields

from .....messaging.agent_message import AgentMessage, AgentMessageSchema
from ..message_types import MEDIATE_DENY, PROTOCOL_PACKAGE
Expand All @@ -24,20 +21,10 @@ class Meta:

def __init__(
self,
*,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
**kwargs,
):
"""Initialize mediation deny object.

Args:
mediator_terms: Terms that were agreed by the recipient
recipient_terms: Terms that recipient wants to mediator to agree to
"""
"""Initialize mediation deny object."""
super(MediationDeny, self).__init__(**kwargs)
self.mediator_terms = list(mediator_terms) if mediator_terms else []
self.recipient_terms = list(recipient_terms) if recipient_terms else []


class MediationDenySchema(AgentMessageSchema):
Expand All @@ -47,12 +34,3 @@ class Meta:
"""Mediation deny schema metadata."""

model_class = MediationDeny

mediator_terms = fields.List(
fields.Str(metadata={"description": "Terms for mediator to agree"}),
required=False,
)
recipient_terms = fields.List(
fields.Str(metadata={"description": "Terms for recipient to agree"}),
required=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Used to notify mediation client of a granted mediation request.
"""

from typing import Sequence
from typing import Optional, Sequence

from marshmallow import fields

Expand All @@ -29,8 +29,8 @@ class Meta:
def __init__(
self,
*,
endpoint: str = None,
routing_keys: Sequence[str] = None,
endpoint: Optional[str] = None,
routing_keys: Optional[Sequence[str]] = None,
**kwargs,
):
"""Initialize mediation grant object.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""mediate-request message used to request mediation from a mediator."""

from typing import Sequence

from marshmallow import fields

from .....messaging.agent_message import AgentMessage, AgentMessageSchema
from ..message_types import MEDIATE_REQUEST, PROTOCOL_PACKAGE
Expand All @@ -22,22 +19,9 @@ class Meta:
message_type = MEDIATE_REQUEST
schema_class = "MediationRequestSchema"

def __init__(
self,
*,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
**kwargs,
):
"""Initialize mediation request object.

Args:
mediator_terms: Mediator's terms for granting mediation.
recipient_terms: Recipient's proposed mediation terms.
"""
def __init__(self, **kwargs):
"""Initialize mediation request object."""
super(MediationRequest, self).__init__(**kwargs)
self.mediator_terms = list(mediator_terms) if mediator_terms else []
self.recipient_terms = list(recipient_terms) if recipient_terms else []


class MediationRequestSchema(AgentMessageSchema):
Expand All @@ -47,27 +31,3 @@ class Meta:
"""Mediation request schema metadata."""

model_class = MediationRequest

mediator_terms = fields.List(
fields.Str(
metadata={
"description": (
"Indicate terms that the mediator requires the recipient to"
" agree to"
)
}
),
required=False,
metadata={"description": "List of mediator rules for recipient"},
)
recipient_terms = fields.List(
fields.Str(
metadata={
"description": (
"Indicate terms that the recipient requires the mediator to"
" agree to"
)
}
),
required=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ class TestMediateDeny(MessageTest, TestCase):
TYPE = MEDIATE_DENY
CLASS = MediationDeny
SCHEMA = MediationDenySchema
VALUES = {"mediator_terms": ["test", "terms"], "recipient_terms": ["test", "terms"]}
VALUES = {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ class TestMediateRequest(MessageTest, TestCase):
TYPE = MEDIATE_REQUEST
CLASS = MediationRequest
SCHEMA = MediationRequestSchema
VALUES = {"mediator_terms": ["test", "terms"], "recipient_terms": ["test", "terms"]}
VALUES = {}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Store state for Mediation requests."""

from typing import Sequence
from typing import Optional, Sequence

from marshmallow import EXCLUDE, fields

Expand Down Expand Up @@ -33,14 +33,15 @@ class Meta:
def __init__(
self,
*,
mediation_id: str = None,
state: str = None,
role: str = None,
connection_id: str = None,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
routing_keys: Sequence[str] = None,
endpoint: str = None,
mediation_id: Optional[str] = None,
state: Optional[str] = None,
role: Optional[str] = None,
connection_id: Optional[str] = None,
routing_keys: Optional[Sequence[str]] = None,
endpoint: Optional[str] = None,
# Included for record backwards compat
mediator_terms: Optional[Sequence[str]] = None,
recipient_terms: Optional[Sequence[str]] = None,
**kwargs,
):
"""__init__.
Expand All @@ -50,8 +51,6 @@ def __init__(
state (str): state, defaults to 'request_received'
role (str): role in mediation, defaults to 'server'
connection_id (str): ID of connection requesting or managing mediation
mediator_terms (Sequence[str]): mediator_terms
recipient_terms (Sequence[str]): recipient_terms
routing_keys (Sequence[str]): keys in mediator control used to
receive incoming messages
endpoint (str): mediators endpoint
Expand All @@ -61,8 +60,6 @@ def __init__(
super().__init__(mediation_id, state or self.STATE_REQUEST, **kwargs)
self.role = role if role else self.ROLE_SERVER
self.connection_id = connection_id
self.mediator_terms = list(mediator_terms) if mediator_terms else []
self.recipient_terms = list(recipient_terms) if recipient_terms else []
self.routing_keys = list(routing_keys) if routing_keys else []
self.endpoint = endpoint

Expand All @@ -79,6 +76,8 @@ def __eq__(self, other: "MediationRecord"):
@property
def mediation_id(self) -> str:
"""Get Mediation ID."""
if not self._id:
raise ValueError("Record not yet stored")
return self._id

@property
Expand Down Expand Up @@ -109,8 +108,6 @@ def record_value(self) -> dict:
return {
prop: getattr(self, prop)
for prop in (
"mediator_terms",
"recipient_terms",
"routing_keys",
"endpoint",
)
Expand Down Expand Up @@ -170,10 +167,12 @@ class Meta:
mediation_id = fields.Str(required=False)
role = fields.Str(required=True)
connection_id = fields.Str(required=True)
mediator_terms = fields.List(fields.Str(), required=False)
recipient_terms = fields.List(fields.Str(), required=False)
routing_keys = fields.List(
fields.Str(validate=DID_KEY_VALIDATE, metadata={"example": DID_KEY_EXAMPLE}),
required=False,
)
endpoint = fields.Str(required=False)

# Included for backwards compat with old records
mediator_terms = fields.List(fields.Str(), required=False)
recipient_terms = fields.List(fields.Str(), required=False)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for models."""
Loading