diff --git a/.gitignore b/.gitignore index a76001e..5c026fa 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__ fastbencode.egg-info *.pyc dist +*~ diff --git a/README.md b/README.md index 84f0f78..5545d84 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ Example: >>> bdecode(bencode([1, 2, b'a', {b'd': 3}])) [1, 2, b'a', {b'd': 3}] +The default ``bencode``/``bdecode`` functions just operate on +bytestrings. Use ``bencode_utf8`` / ``bdecode_utf8`` to +serialize/deserialize all plain strings as UTF-8 bytestrings. +Note that for performance reasons, all dictionary keys still have to be +bytestrings. + License ======= fastbencode is available under the GNU GPL, version 2 or later. diff --git a/fastbencode/_bencode_py.py b/fastbencode/_bencode_py.py index f04e182..6d33f2e 100644 --- a/fastbencode/_bencode_py.py +++ b/fastbencode/_bencode_py.py @@ -21,27 +21,28 @@ class BDecoder: - def __init__(self, yield_tuples=False) -> None: + def __init__(self, yield_tuples=False, bytestring_encoding=None) -> None: """Constructor. :param yield_tuples: if true, decode "l" elements as tuples rather than lists. """ self.yield_tuples = yield_tuples + self.bytestring_encoding = bytestring_encoding decode_func = {} decode_func[b'l'] = self.decode_list decode_func[b'd'] = self.decode_dict decode_func[b'i'] = self.decode_int - decode_func[b'0'] = self.decode_string - decode_func[b'1'] = self.decode_string - decode_func[b'2'] = self.decode_string - decode_func[b'3'] = self.decode_string - decode_func[b'4'] = self.decode_string - decode_func[b'5'] = self.decode_string - decode_func[b'6'] = self.decode_string - decode_func[b'7'] = self.decode_string - decode_func[b'8'] = self.decode_string - decode_func[b'9'] = self.decode_string + decode_func[b'0'] = self.decode_bytes + decode_func[b'1'] = self.decode_bytes + decode_func[b'2'] = self.decode_bytes + decode_func[b'3'] = self.decode_bytes + decode_func[b'4'] = self.decode_bytes + decode_func[b'5'] = self.decode_bytes + decode_func[b'6'] = self.decode_bytes + decode_func[b'7'] = self.decode_bytes + decode_func[b'8'] = self.decode_bytes + decode_func[b'9'] = self.decode_bytes self.decode_func = decode_func def decode_int(self, x, f): @@ -54,13 +55,16 @@ def decode_int(self, x, f): raise ValueError return (n, newf + 1) - def decode_string(self, x, f): + def decode_bytes(self, x, f): colon = x.index(b':', f) n = int(x[f:colon]) if x[f:f + 1] == b'0' and colon != f + 1: raise ValueError colon += 1 - return (x[colon:colon + n], colon + n) + d = x[colon:colon + n] + if self.bytestring_encoding: + d = d.decode(self.bytestring_encoding) + return (d, colon + n) def decode_list(self, x, f): r, f = [], f + 1 @@ -75,7 +79,7 @@ def decode_dict(self, x, f): r, f = {}, f + 1 lastkey = None while x[f:f + 1] != b'e': - k, f = self.decode_string(x, f) + k, f = self.decode_bytes(x, f) if lastkey is not None and lastkey >= k: raise ValueError lastkey = k @@ -100,6 +104,9 @@ def bdecode(self, x): _tuple_decoder = BDecoder(True) bdecode_as_tuple = _tuple_decoder.bdecode +_utf8_decoder = BDecoder(bytestring_encoding='utf-8') +bdecode_utf8 = _utf8_decoder.bdecode + class Bencached: __slots__ = ['bencoded'] @@ -108,55 +115,72 @@ def __init__(self, s) -> None: self.bencoded = s -def encode_bencached(x, r): - r.append(x.bencoded) +class BEncoder: + + def __init__(self, bytestring_encoding=None): + self.bytestring_encoding = bytestring_encoding + self.encode_func: Dict[Type, Callable[[object, List[bytes]], None]] = { + Bencached: self.encode_bencached, + int: self.encode_int, + bytes: self.encode_bytes, + list: self.encode_list, + tuple: self.encode_list, + dict: self.encode_dict, + bool: self.encode_bool, + str: self.encode_str, + } + def encode_bencached(self, x, r): + r.append(x.bencoded) -def encode_bool(x, r): - encode_int(int(x), r) + def encode_bool(self, x, r): + self.encode_int(int(x), r) -def encode_int(x, r): - r.extend((b'i', int_to_bytes(x), b'e')) + def encode_int(self, x, r): + r.extend((b'i', int_to_bytes(x), b'e')) -def encode_string(x, r): - r.extend((int_to_bytes(len(x)), b':', x)) + def encode_bytes(self, x, r): + r.extend((int_to_bytes(len(x)), b':', x)) -def encode_list(x, r): - r.append(b'l') - for i in x: - encode_func[type(i)](i, r) - r.append(b'e') + def encode_list(self, x, r): + r.append(b'l') + for i in x: + self.encode(i, r) + r.append(b'e') -def encode_dict(x, r): - r.append(b'd') - ilist = sorted(x.items()) - for k, v in ilist: - r.extend((int_to_bytes(len(k)), b':', k)) - encode_func[type(v)](v, r) - r.append(b'e') + def encode_dict(self, x, r): + r.append(b'd') + ilist = sorted(x.items()) + for k, v in ilist: + r.extend((int_to_bytes(len(k)), b':', k)) + self.encode(v, r) + r.append(b'e') + def encode_str(self, x, r): + if self.bytestring_encoding is None: + raise TypeError("string found but no encoding specified. " + "Use bencode_utf8 rather bencode?") + return self.encode_bytes(x.encode(self.bytestring_encoding), r) -encode_func: Dict[Type, Callable[[object, List[bytes]], None]] = {} -encode_func[type(Bencached(0))] = encode_bencached -encode_func[int] = encode_int + def encode(self, x, r): + self.encode_func[type(x)](x, r) def int_to_bytes(n): return b'%d' % n - -encode_func[bytes] = encode_string -encode_func[list] = encode_list -encode_func[tuple] = encode_list -encode_func[dict] = encode_dict -encode_func[bool] = encode_bool - - def bencode(x): r = [] - encode_func[type(x)](x, r) + encoder = BEncoder() + encoder.encode(x, r) + return b''.join(r) + +def bencode_utf8(x): + r = [] + encoder = BEncoder(bytestring_encoding='utf-8') + encoder.encode(x, r) return b''.join(r) diff --git a/fastbencode/_bencode_pyx.pyx b/fastbencode/_bencode_pyx.pyx index 8561e69..32a9ef3 100644 --- a/fastbencode/_bencode_pyx.pyx +++ b/fastbencode/_bencode_pyx.pyx @@ -46,15 +46,22 @@ from cpython.mem cimport ( PyMem_Malloc, PyMem_Realloc, ) +from cpython.unicode cimport ( + PyUnicode_FromEncodedObject, + PyUnicode_FromStringAndSize, + PyUnicode_Check, + ) from cpython.tuple cimport ( PyTuple_CheckExact, ) from libc.stdlib cimport ( strtol, + free, ) from libc.string cimport ( memcpy, + strdup, ) cdef extern from "python-compat.h": @@ -79,9 +86,10 @@ cdef class Decoder: cdef readonly char *tail cdef readonly int size cdef readonly int _yield_tuples + cdef readonly char *_bytestring_encoding cdef object text - def __init__(self, s, yield_tuples=0): + def __init__(self, s, yield_tuples=0, str bytestring_encoding=None): """Initialize decoder engine. @param s: Python string. """ @@ -92,6 +100,13 @@ cdef class Decoder: self.tail = PyBytes_AS_STRING(s) self.size = PyBytes_GET_SIZE(s) self._yield_tuples = int(yield_tuples) + if bytestring_encoding is None: + self._bytestring_encoding = NULL + else: + self._bytestring_encoding = strdup(bytestring_encoding.encode('utf-8')) + + def __dealloc__(self): + free(self._bytestring_encoding) def decode(self): result = self._decode_object() @@ -112,7 +127,7 @@ cdef class Decoder: try: ch = self.tail[0] if c'0' <= ch <= c'9': - return self._decode_string() + return self._decode_bytes() elif ch == c'l': D_UPDATE_TAIL(self, 1) return self._decode_list() @@ -155,12 +170,12 @@ cdef class Decoder: D_UPDATE_TAIL(self, i+1) return ret - cdef object _decode_string(self): + cdef object _decode_bytes(self): cdef int n cdef char *next_tail # strtol allows leading whitespace, negatives, and leading zeros # however, all callers have already checked that '0' <= tail[0] <= '9' - # or they wouldn't have called _decode_string + # or they wouldn't have called _decode_bytes # strtol will stop at trailing whitespace, etc n = strtol(self.tail, &next_tail, 10) if next_tail == NULL or next_tail[0] != c':': @@ -171,13 +186,22 @@ cdef class Decoder: raise ValueError('leading zeros are not allowed') D_UPDATE_TAIL(self, next_tail - self.tail + 1) if n == 0: - return b'' + if self._bytestring_encoding == NULL: + return b'' + else: + return '' if n > self.size: raise ValueError('stream underflow') if n < 0: raise ValueError('string size below zero: %d' % n) - result = PyBytes_FromStringAndSize(self.tail, n) + if self._bytestring_encoding == NULL: + result = PyBytes_FromStringAndSize(self.tail, n) + elif self._bytestring_encoding == b'utf-8': + result = PyUnicode_FromStringAndSize(self.tail, n) + else: + result = PyBytes_FromStringAndSize(self.tail, n) + result = PyUnicode_FromEncodedObject(result, self._bytestring_encoding, NULL) D_UPDATE_TAIL(self, n) return result @@ -214,7 +238,7 @@ cdef class Decoder: # keys should be strings only if self.tail[0] < c'0' or self.tail[0] > c'9': raise ValueError('key was not a simple string.') - key = self._decode_string() + key = self._decode_bytes() if lastkey is not None and lastkey >= key: raise ValueError('dict keys disordered') else: @@ -235,6 +259,11 @@ def bdecode_as_tuple(object s): return Decoder(s, True).decode() +def bdecode_utf8(object s): + """Decode string x to Python object, decoding bytestrings as UTF8 strings.""" + return Decoder(s, bytestring_encoding='utf-8').decode() + + class Bencached(object): __slots__ = ['bencoded'] @@ -254,8 +283,9 @@ cdef class Encoder: cdef readonly int size cdef readonly char *buffer cdef readonly int maxsize + cdef readonly object _bytestring_encoding - def __init__(self, int maxsize=INITSIZE): + def __init__(self, int maxsize=INITSIZE, str bytestring_encoding=None): """Initialize encoder engine @param maxsize: initial size of internal char buffer """ @@ -273,6 +303,8 @@ cdef class Encoder: self.maxsize = maxsize self.tail = p + self._bytestring_encoding = bytestring_encoding + def __dealloc__(self): PyMem_Free(self.buffer) self.buffer = NULL @@ -329,7 +361,7 @@ cdef class Encoder: E_UPDATE_TAIL(self, n) return 1 - cdef int _encode_string(self, x) except 0: + cdef int _encode_bytes(self, x) except 0: cdef int n cdef Py_ssize_t x_len x_len = PyBytes_GET_SIZE(x) @@ -341,6 +373,12 @@ cdef class Encoder: E_UPDATE_TAIL(self, n + x_len) return 1 + cdef int _encode_string(self, x) except 0: + if self._bytestring_encoding is None: + raise TypeError("string found but no encoding specified. " + "Use bencode_utf8 rather bencode?") + return self._encode_bytes(x.encode(self._bytestring_encoding)) + cdef int _encode_list(self, x) except 0: self._ensure_buffer(1) self.tail[0] = c'l' @@ -362,7 +400,7 @@ cdef class Encoder: for k in sorted(x): if not PyBytes_CheckExact(k): raise TypeError('key in dict should be string') - self._encode_string(k) + self._encode_bytes(k) self.process(x[k]) self._ensure_buffer(1) @@ -374,7 +412,7 @@ cdef class Encoder: BrzPy_EnterRecursiveCall(" while bencode encoding") try: if PyBytes_CheckExact(x): - self._encode_string(x) + self._encode_bytes(x) elif PyInt_CheckExact(x) and x.bit_length() < 32: self._encode_int(x) elif PyLong_CheckExact(x): @@ -385,6 +423,8 @@ cdef class Encoder: self._encode_dict(x) elif PyBool_Check(x): self._encode_int(int(x)) + elif PyUnicode_Check(x): + self._encode_string(x) elif isinstance(x, Bencached): self._append_string(x.bencoded) else: @@ -394,7 +434,17 @@ cdef class Encoder: def bencode(x): - """Encode Python object x to string""" + """Encode Python object x to bytestring""" encoder = Encoder() encoder.process(x) return encoder.to_bytes() + + +def bencode_utf8(x): + """Encode Python object x to bytestring. + + Encode any strings as UTF8 + """ + encoder = Encoder(bytestring_encoding='utf-8') + encoder.process(x) + return encoder.to_bytes() diff --git a/fastbencode/tests/test_bencode.py b/fastbencode/tests/test_bencode.py index 67f7d5c..314967a 100644 --- a/fastbencode/tests/test_bencode.py +++ b/fastbencode/tests/test_bencode.py @@ -351,6 +351,47 @@ def test_decoder_type_error(self): self.assertRaises(TypeError, self.module.bdecode, 1) +class TestBdecodeUtf8(TestCase): + + module = None + + def _check(self, expected, source): + self.assertEqual(expected, self.module.bdecode_utf8(source)) + + def _run_check_error(self, exc, bad): + """Check that bdecoding a string raises a particular exception.""" + self.assertRaises(exc, self.module.bdecode_utf8, bad) + + def test_string(self): + self._check('', b'0:') + self._check('aäc', b'4:a\xc3\xa4c') + self._check('1234567890', b'10:1234567890') + + def test_large_string(self): + self.assertRaises( + ValueError, self.module.bdecode_utf8, b"2147483639:foo") + + def test_malformed_string(self): + self._run_check_error(ValueError, b'10:x') + self._run_check_error(ValueError, b'10:') + self._run_check_error(ValueError, b'10') + self._run_check_error(ValueError, b'01:x') + self._run_check_error(ValueError, b'00:') + self._run_check_error(ValueError, b'35208734823ljdahflajhdf') + self._run_check_error(ValueError, b'432432432432432:foo') + self._run_check_error(ValueError, b' 1:x') # leading whitespace + self._run_check_error(ValueError, b'-1:x') # negative + self._run_check_error(ValueError, b'1 x') # space vs colon + self._run_check_error(ValueError, b'1x') # missing colon + self._run_check_error(ValueError, (b'1' * 1000) + b':') + + def test_empty_string(self): + self.assertRaises(ValueError, self.module.bdecode_utf8, b'') + + def test_invalid_utf8(self): + self._run_check_error(UnicodeDecodeError, b'3:\xff\xfe\xfd') + + class TestBencodeEncode(TestCase): module = None @@ -414,3 +455,32 @@ def test_invalid_dict(self): def test_bool(self): self._check(b'i1e', True) self._check(b'i0e', False) + + +class TestBencodeEncodeUtf8(TestCase): + + module = None + + def _check(self, expected, source): + self.assertEqual(expected, self.module.bencode_utf8(source)) + + def test_string(self): + self._check(b'0:', '') + self._check(b'3:abc', 'abc') + self._check(b'10:1234567890', '1234567890') + + def test_list(self): + self._check(b'le', []) + self._check(b'li1ei2ei3ee', [1, 2, 3]) + self._check(b'll5:Alice3:Bobeli2ei3eee', [['Alice', 'Bob'], [2, 3]]) + + def test_list_as_tuple(self): + self._check(b'le', ()) + self._check(b'li1ei2ei3ee', (1, 2, 3)) + self._check(b'll5:Alice3:Bobeli2ei3eee', (('Alice', 'Bob'), (2, 3))) + + def test_dict(self): + self._check(b'de', {}) + self._check(b'd3:agei25e4:eyes4:bluee', {b'age': 25, b'eyes': 'blue'}) + self._check(b'd8:spam.mp3d6:author5:Alice6:lengthi100000eee', + {b'spam.mp3': {b'author': b'Alice', b'length': 100000}})