diff --git a/fastbencode/_bencode_py.py b/fastbencode/_bencode_py.py index 6bd4ef3..279d545 100644 --- a/fastbencode/_bencode_py.py +++ b/fastbencode/_bencode_py.py @@ -21,13 +21,14 @@ class BDecoder: - def __init__(self, yield_tuples=False) -> None: + def __init__(self, yield_tuples=False, bytestring_encoding=None) -> None: """Constructor. :param yield_tuples: if true, decode "l" elements as tuples rather than lists. """ self.yield_tuples = yield_tuples + self.bytestring_encoding = bytestring_encoding decode_func = {} decode_func[b'l'] = self.decode_list decode_func[b'd'] = self.decode_dict @@ -60,7 +61,10 @@ def decode_bytes(self, x, f): if x[f:f + 1] == b'0' and colon != f + 1: raise ValueError colon += 1 - return (x[colon:colon + n], colon + n) + d = x[colon:colon + n] + if self.bytestring_encoding: + d = d.decode(self.bytestring_encoding) + return (d, colon + n) def decode_list(self, x, f): r, f = [], f + 1 @@ -100,6 +104,9 @@ def bdecode(self, x): _tuple_decoder = BDecoder(True) bdecode_as_tuple = _tuple_decoder.bdecode +_utf8_decoder = BDecoder(bytestring_encoding='utf-8') +bdecode_utf8 = _utf8_decoder.bdecode + class Bencached: __slots__ = ['bencoded'] diff --git a/fastbencode/_bencode_pyx.pyx b/fastbencode/_bencode_pyx.pyx index 0ca90dd..025bca4 100644 --- a/fastbencode/_bencode_pyx.pyx +++ b/fastbencode/_bencode_pyx.pyx @@ -46,15 +46,21 @@ from cpython.mem cimport ( PyMem_Malloc, PyMem_Realloc, ) +from cpython.unicode cimport ( + PyUnicode_FromEncodedObject, + PyUnicode_FromStringAndSize, + ) from cpython.tuple cimport ( PyTuple_CheckExact, ) from libc.stdlib cimport ( strtol, + free, ) from libc.string cimport ( memcpy, + strdup, ) cdef extern from "python-compat.h": @@ -79,9 +85,10 @@ cdef class Decoder: cdef readonly char *tail cdef readonly int size cdef readonly int _yield_tuples + cdef readonly char *_bytestring_encoding cdef object text - def __init__(self, s, yield_tuples=0): + def __init__(self, s, yield_tuples=0, str bytestring_encoding=None): """Initialize decoder engine. @param s: Python string. """ @@ -92,6 +99,13 @@ cdef class Decoder: self.tail = PyBytes_AS_STRING(s) self.size = PyBytes_GET_SIZE(s) self._yield_tuples = int(yield_tuples) + if bytestring_encoding is None: + self._bytestring_encoding = NULL + else: + self._bytestring_encoding = strdup(bytestring_encoding.encode('utf-8')) + + def __dealloc__(self): + free(self._bytestring_encoding) def decode(self): result = self._decode_object() @@ -171,13 +185,22 @@ cdef class Decoder: raise ValueError('leading zeros are not allowed') D_UPDATE_TAIL(self, next_tail - self.tail + 1) if n == 0: - return b'' + if self._bytestring_encoding == NULL: + return b'' + else: + return '' if n > self.size: raise ValueError('stream underflow') if n < 0: raise ValueError('string size below zero: %d' % n) - result = PyBytes_FromStringAndSize(self.tail, n) + if self._bytestring_encoding == NULL: + result = PyBytes_FromStringAndSize(self.tail, n) + elif self._bytestring_encoding == b'utf-8': + result = PyUnicode_FromStringAndSize(self.tail, n) + else: + result = PyBytes_FromStringAndSize(self.tail, n) + result = PyUnicode_FromEncodedObject(result, self._bytestring_encoding, NULL) D_UPDATE_TAIL(self, n) return result @@ -235,6 +258,11 @@ def bdecode_as_tuple(object s): return Decoder(s, True).decode() +def bdecode_utf8(object s): + """Decode string x to Python object, decoding bytestrings as UTF8 strings.""" + return Decoder(s, bytestring_encoding='utf-8').decode() + + class Bencached(object): __slots__ = ['bencoded'] diff --git a/fastbencode/tests/test_bencode.py b/fastbencode/tests/test_bencode.py index 61cd8b5..7073f42 100644 --- a/fastbencode/tests/test_bencode.py +++ b/fastbencode/tests/test_bencode.py @@ -351,6 +351,46 @@ def test_decoder_type_error(self): self.assertRaises(TypeError, self.module.bdecode, 1) +class TestBdecodeUtf8(TestCase): + + module = None + + def _check(self, expected, source): + self.assertEqual(expected, self.module.bdecode_utf8(source)) + + def _run_check_error(self, exc, bad): + """Check that bdecoding a string raises a particular exception.""" + self.assertRaises(exc, self.module.bdecode_utf8, bad) + + def test_string(self): + self._check('', b'0:') + self._check('aäc', b'4:a\xc3\xa4c') + self._check('1234567890', b'10:1234567890') + + def test_large_string(self): + self.assertRaises(ValueError, self.module.bdecode_utf8, b"2147483639:foo") + + def test_malformed_string(self): + self._run_check_error(ValueError, b'10:x') + self._run_check_error(ValueError, b'10:') + self._run_check_error(ValueError, b'10') + self._run_check_error(ValueError, b'01:x') + self._run_check_error(ValueError, b'00:') + self._run_check_error(ValueError, b'35208734823ljdahflajhdf') + self._run_check_error(ValueError, b'432432432432432:foo') + self._run_check_error(ValueError, b' 1:x') # leading whitespace + self._run_check_error(ValueError, b'-1:x') # negative + self._run_check_error(ValueError, b'1 x') # space vs colon + self._run_check_error(ValueError, b'1x') # missing colon + self._run_check_error(ValueError, (b'1' * 1000) + b':') + + def test_empty_string(self): + self.assertRaises(ValueError, self.module.bdecode_utf8, b'') + + def test_invalid_utf8(self): + self._run_check_error(UnicodeDecodeError, b'3:\xff\xfe\xfd') + + class TestBencodeEncode(TestCase): module = None