Skip to content

Commit

Permalink
finish subscription websocket support via flask_sock
Browse files Browse the repository at this point in the history
spec: https://atproto.com/specs/event-stream

subscription handler functions are generators that yield (dict header, dict payload) tuples. flask_server uses flask-sock to serve these over websockets, DAG-CBOR encoding each header/payload pair and concatenating them into a single message.
  • Loading branch information
snarfed committed Aug 17, 2023
1 parent 19cd0ef commit 7e34048
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ You can also register a method handler with [`Server.register`](https://lexrpc.r
server.register('com.example.my-query', my_query_handler)
```

[Event stream methods with type `subscription`](https://atproto.com/specs/event-stream) should be generators that `yield` messages to send to the client. They take parameters as kwargs, but no positional `input`.
[Event stream methods with type `subscription`](https://atproto.com/specs/event-stream) should be generators that `yield` frames to send to the client. [Each frame is a `(header dict, payload dict)` tuple](https://atproto.com/specs/event-stream#framing) that will be DAG-CBOR encoded and sent to the websocket client. Subscription methods take parameters as kwargs, but no positional `input`.

```
@server.method('com.example.count')
Expand Down
2 changes: 1 addition & 1 deletion lexrpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def decode_params(self, method_nsid, params):
try:
if type == 'number':
decoded[name] = float(val)
elif type == 'int':
elif type in ('int', 'integer'):
decoded[name] = int(val)
except ValueError as e:
e.args = [f'{e.args[0]} for {type} parameter {name}']
Expand Down
29 changes: 26 additions & 3 deletions lexrpc/flask_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Flask handler for /xrpc/... endpoint."""
import logging

import dag_cbor
from flask import request
from flask.json import jsonify
from flask.views import View
from flask_sock import Sock
from jsonschema import ValidationError
from simple_websocket import ConnectionClosed

from .base import NSID_RE

Expand All @@ -29,17 +31,18 @@ def init_flask(xrpc_server, app):
logger.info(f'Registering {xrpc_server} with {app}')

sock = Sock(app)
for nsid, fn in xrpc_server._methods.items():
for nsid, _ in xrpc_server._methods.items():
if xrpc_server._defs[nsid]['type'] == 'subscription':
sock.route(f'/xrpc/{nsid}')(fn)
sock.route(f'/xrpc/{nsid}')(
lambda ws: subscription(ws, xrpc_server, nsid))

app.add_url_rule('/xrpc/<nsid>',
view_func=XrpcEndpoint.as_view('xrpc-endpoint', xrpc_server),
methods=['GET', 'POST', 'OPTIONS'])


class XrpcEndpoint(View):
"""Handles inbound XRPC requests.
"""Handles inbound XRPC query and procedure (but not subscription) methods.
Attributes:
server: :class:`lexrpc.Server`
Expand Down Expand Up @@ -93,3 +96,23 @@ def dispatch_request(self, nsid):
if not isinstance(output, (str, bytes)):
return {'message': f'Expected str or bytes output to match {out_encoding}, got {output.__class__}'}, 500
return output, RESPONSE_HEADERS


def subscription(ws, xrpc_server, nsid):
"""Handles inbound XRPC subscription methods over websocket.
Args:
ws: :class:`simple_websocket.ws.WSConnection`
xrpc_server: :class:`lexrpc.Server`
nsid: str, XRPC method NSID
"""
logger.debug(f'New websocket client for {nsid}')
params = xrpc_server.decode_params(nsid, request.args.items(multi=True))

try:
for header, payload in xrpc_server.call(nsid, **params):
# TODO: validate header, payload?
logger.debug(f'Sending to {nsid} websocket client: {header} {payload}')
ws.send(dag_cbor.encode(header) + dag_cbor.encode(payload))
except ConnectionClosed as cc:
logger.debug(f'Websocket client disconnected from {nsid}: {cc}')
4 changes: 3 additions & 1 deletion lexrpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def register(self, nsid, fn):
def call(self, nsid, input=None, **params):
"""Calls an XRPC query or procedure method.
For subscriptions, returns a generator that yields (header dict, payload
dict) tuples to be DAG-CBOR encoded and sent to the websocket client.
Args:
nsid: str, method NSID
input: dict or bytes, input body, optional for subscriptions
Expand Down Expand Up @@ -94,6 +97,5 @@ def loggable(val):

if not subscription:
self._maybe_validate(nsid, 'output', output)
# TODO: validate subscription yielded items against message

return output
60 changes: 50 additions & 10 deletions lexrpc/tests/test_flask_server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
"""Unit tests for flask_server.py."""
from threading import Thread
from unittest import skip, TestCase

import dag_cbor
from flask import Flask
from simple_websocket import ConnectionClosed

from ..flask_server import init_flask
from ..flask_server import init_flask, subscription
from .lexicons import LEXICONS
from .test_server import server


class FakeConnection:
"""Fake of :class:`simple_websocket.ws.WSConnection`."""
exc = None
sent = []

@classmethod
def send(cls, msg):
if cls.exc:
raise cls.exc
cls.sent.append(msg)


class XrpcEndpointTest(TestCase):
maxDiff = None

Expand All @@ -18,6 +33,8 @@ def setUpClass(cls):

def setUp(self):
self.client = self.app.test_client()
FakeConnection.exc = None
FakeConnection.sent = []

def test_procedure(self):
input = {
Expand Down Expand Up @@ -60,16 +77,38 @@ def test_query_boolean_param(self):

resp = self.client.get('/xrpc/io.example.query?z=foolz')
self.assertEqual(400, resp.status_code)
self.assertEqual("Got 'foolz' for boolean parameter z, expected true or false",
resp.json['message'])
self.assertEqual(
"Got 'foolz' for boolean parameter z, expected true or false",
resp.json['message'])

# TODO
# needs websocket test client, but flask-sock doesn't have one yet
# https://github.com/miguelgrinberg/flask-sock/issues/23
@skip
def test_subscription(self):
resp = self.client.get('/xrpc/io.example.subscribe?start=3&end=6')
self.assertEqual(200, resp.status_code, resp.json)
def subscribe():
with self.app.test_request_context(query_string={'start': 3, 'end': 6}):
subscription(FakeConnection, server, 'io.example.subscribe')

subscriber = Thread(target=subscribe)
subscriber.start()
subscriber.join()

header_bytes = dag_cbor.encode({'hea': 'der'})
self.assertEqual([
header_bytes + dag_cbor.encode({'num': 3}),
header_bytes + dag_cbor.encode({'num': 4}),
header_bytes + dag_cbor.encode({'num': 5}),
], FakeConnection.sent)

def test_subscription_client_disconnects(self):
FakeConnection.exc = ConnectionClosed()

def subscribe():
with self.app.test_request_context(query_string={'start': 3, 'end': 6}):
subscription(FakeConnection, server, 'io.example.subscribe')


subscriber = Thread(target=subscribe)
subscriber.start()
subscriber.join()
self.assertEqual([], FakeConnection.sent)

# TODO
@skip
Expand All @@ -87,7 +126,8 @@ def test_procedure_missing_input(self):
# TODO
@skip
def test_procedure_bad_input(self):
resp = self.client.post('/xrpc/io.example.procedure', json={'foo': 2, 'bar': 3})
resp = self.client.post('/xrpc/io.example.procedure',
json={'foo': 2, 'bar': 3})
self.assertEqual(400, resp.status_code)
self.assertTrue(resp.json['message'].startswith(
'Error validating io.example.procedure input:'))
Expand Down
8 changes: 4 additions & 4 deletions lexrpc/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def encodings(input, **params):
@server.method('io.example.subscribe')
def subscribe(start=None, end=None):
for num in range(start, end):
yield {'num': num}
yield {'hea': 'der'}, {'num': num}


class ServerTest(TestCase):
Expand Down Expand Up @@ -144,9 +144,9 @@ def test_array(self):
def test_subscription(self):
gen = server.call('io.example.subscribe', start=3, end=6)
self.assertEqual([
{'num': 3},
{'num': 4},
{'num': 5},
({'hea': 'der'}, {'num': 3}),
({'hea': 'der'}, {'num': 4}),
({'hea': 'der'}, {'num': 5}),
], list(gen))

def test_unknown_methods(self):
Expand Down

0 comments on commit 7e34048

Please sign in to comment.