-
Notifications
You must be signed in to change notification settings - Fork 0
/
dispatcher.py
121 lines (102 loc) · 3.67 KB
/
dispatcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import os, struct, sys
from framing import AMQP_FRAME, Frame, encode, decode
from util import Buffer, parse
# Wire layout of the protocol header, in network byte order: a 4-byte magic
# string followed by four unsigned bytes (unpacked below as protocol id,
# major, minor and revision).
PROTO_HDR_FMT = "!4sBBBB"

# Number of bytes occupied by one packed protocol header.
PROTO_HDR_SIZE = struct.Struct(PROTO_HDR_FMT).size

# Sanity check: the format above must describe exactly eight bytes.
assert PROTO_HDR_SIZE == 8
try:
    _STRING_TYPES = (basestring,)  # Python 2: covers both str and unicode
except NameError:
    _STRING_TYPES = (str,)  # Python 3: basestring no longer exists


class Dispatcher:
    """Exchange a protocol header, then decode and dispatch incoming frames.

    Incoming bytes are fed in through write(); outgoing bytes are drained
    through read()/peek()/pending().  The current parsing step is kept in
    ``self.state`` and starts out as the protocol-header check.

    This class references several attributes on ``self`` that it does not
    define: ``type_decoder``, ``type_encoder``, ``unhandled``, ``tick`` and
    the ``do_<NAME>`` handlers.  They must be supplied elsewhere (presumably
    by subclasses -- confirm against the callers).
    """

    def __init__(self, protocol_id, frame_type):
        """Set up buffers and tracing and queue our own protocol header.

        protocol_id -- value packed into the outgoing header and required
                       in the incoming one.
        frame_type  -- frame type passed to Frame() by post_frame().
        """
        self.protocol_id = protocol_id
        self.frame_type = frame_type
        # Short identifier used as part of every trace prefix.
        self.id = "%X" % id(self)
        self._tracing = set()
        # AMQP_TRACE holds whitespace-separated trace category names.
        self.tracing(*os.environ.get("AMQP_TRACE", "").split())
        self.multiline = False
        self.input = Buffer()
        # The magic is a bytes literal: struct's "4s" field needs bytes on
        # Python 3, and on Python 2 b"AMQP" is the same object type as "AMQP".
        self.output = Buffer(
            struct.pack(PROTO_HDR_FMT, b"AMQP", self.protocol_id, 1, 0, 0))
        self.state = self.__proto_header

    def tracing(self, *args, **kwargs):
        """Replace the set of enabled trace categories.

        Positional arguments are category names to enable.  Keyword
        arguments enable the category named by the key when the value is
        true.  The "err" category is enabled as well unless an ``err``
        keyword is passed explicitly.
        """
        names = set(args)
        names.update(name for name, enabled in kwargs.items() if enabled)
        if "err" not in kwargs:
            names.add("err")
        self._tracing = names

    def trace(self, categories, format, *args):
        """Write a trace message to stderr if one of its categories is on.

        categories -- a single category name or an iterable of names; the
                      message is emitted once, for the first enabled one.
        format     -- message text; %-interpolated with ``args`` only when
                      ``args`` is non-empty.
        """
        if isinstance(categories, _STRING_TYPES):
            categories = (categories,)
        for category in categories:
            if category not in self._tracing:
                continue
            prefix = "[%s %s]" % (self.id, category)
            if args:
                message = format % args
            else:
                message = format
            # Repeat the prefix on continuation lines.  In-memory text and
            # text-mode streams use "\n" on every platform, so os.linesep is
            # the wrong thing to search for or to write here.
            message = message.replace("\n", "\n%s " % prefix)
            sys.stderr.write("%s %s\n" % (prefix, message))
            break

    def write(self, bytes):
        """Accept incoming bytes and advance the parser state.

        The bytes are traced under "raw", appended to the input buffer, and
        the current state is handed to util.parse(); whatever it returns
        becomes the new state.
        NOTE(review): parse() is defined outside this file -- presumably it
        keeps calling state functions until one returns None; confirm.
        """
        self.trace("raw", "RECV: %r", bytes)
        self.input.write(bytes)
        self.state = parse(self.state)

    def closed(self):
        """Trace that the input side was closed."""
        self.trace(("raw", "frm"), "CLOSED")

    def error(self, exc):
        """Hook called with an exception; does nothing in this class."""
        pass

    def __proto_header(self):
        """State function: validate the peer's protocol header.

        Returns the framing state once a full, matching header has been
        consumed, and None (implicitly) while fewer than PROTO_HDR_SIZE
        bytes are pending.  Raises ValueError on a mismatching header.
        """
        if self.input.pending() >= PROTO_HDR_SIZE:
            hdr = self.input.read(PROTO_HDR_SIZE)
            received = struct.unpack(PROTO_HDR_FMT, hdr)
            # Fields are (magic, protocol id, major, minor, revision).
            expected = (b"AMQP", self.protocol_id, 1, 0, 0)
            if received == expected:
                return self.__framing
            raise ValueError("bad protocol header")

    def __framing(self):
        """State function: decode and process frames from the input buffer.

        Each decoded frame is consumed from the input and passed to
        process_frame(); a non-None result is returned as the next state.
        Returns None (implicitly) once decode() yields no frame.
        """
        while True:
            f, n = decode(self.input.peek())
            if not f:
                break
            # Only consume the bytes after decode() produced a frame.
            self.input.read(n)
            state = self.process_frame(f)
            if state is not None:
                return state

    def process_frame(self, f):
        """Decode a frame's payload and dispatch it to a handler.

        The decoded body gets the undecoded remainder attached as
        ``body.payload``.  The handler is ``self.do_<body.NAME>`` when
        present, otherwise ``self.unhandled``; it is called with
        (channel, body) and its result is returned.
        """
        body, remainder = self.type_decoder.decode(f.payload)
        body.payload = remainder
        self.trace("frm", "RECV[%s]: %s", f.channel,
                   body.format(self.multiline))
        handler = getattr(self, "do_%s" % body.NAME, self.unhandled)
        return handler(f.channel, body)

    def post_frame(self, channel, body):
        """Encode ``body`` as a frame on ``channel`` and queue it for output.

        ``body.payload``, when truthy, is appended after the encoded body.
        """
        self.trace("frm", "SENT[%s]: %s", channel,
                   body.format(self.multiline))
        encoded = self.type_encoder.encode(body)
        if body.payload:
            encoded += body.payload
        # NOTE(review): the meaning of Frame's third argument (None here) is
        # not visible in this file -- confirm against framing.Frame.
        f = Frame(self.frame_type, channel, None, encoded)
        self.output.write(encode(f))

    def read(self, n=None):
        """Call tick(), then read ``n`` from the output buffer and trace it.

        Returns whatever ``self.output.read(n)`` returns.
        """
        self.tick()
        result = self.output.read(n)
        self.trace("raw", "SENT: %r", result)
        return result

    def peek(self, n=None):
        """Return ``self.output.peek(n)`` without calling tick() or tracing."""
        return self.output.peek(n)

    def pending(self):
        """Call tick(), then return ``self.output.pending()``."""
        self.tick()
        return self.output.pending()