From e942f8e40d6f9bc629c62bcdc2d2c27215640e13 Mon Sep 17 00:00:00 2001 From: tychovrahe Date: Tue, 3 Dec 2024 19:41:31 +0100 Subject: [PATCH 1/5] refactor(core): split polling can_read and reading from USB [no changelog] --- .../upymod/modtrezorio/modtrezorio-hid.h | 15 ++++++++++ .../upymod/modtrezorio/modtrezorio-poll.h | 21 ++++---------- .../upymod/modtrezorio/modtrezorio-webusb.h | 15 ++++++++++ core/mocks/generated/trezorio/__init__.pyi | 10 +++++++ core/src/apps/webauthn/fido2.py | 12 ++++++-- core/src/trezor/wire/codec/codec_v1.py | 12 ++++++-- core/tests/test_trezor.wire.codec.codec_v1.py | 28 +++++++++++++------ 7 files changed, 85 insertions(+), 28 deletions(-) diff --git a/core/embed/upymod/modtrezorio/modtrezorio-hid.h b/core/embed/upymod/modtrezorio/modtrezorio-hid.h index bcf141d7765..cedc2b437d2 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-hid.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-hid.h @@ -141,6 +141,20 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj, mod_trezorio_HID_write); +/// def read(self, buf: bytes) -> int: +/// """ +/// Reads message using USB HID (device) or UDP (emulator). +/// """ +STATIC mp_obj_t mod_trezorio_HID_read(mp_obj_t self, mp_obj_t buffer) { + mp_obj_HID_t *o = MP_OBJ_TO_PTR(self); + mp_buffer_info_t buf = {0}; + mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE); + ssize_t r = usb_hid_read(o->info.iface_num, buf.buf, buf.len); + return MP_OBJ_NEW_SMALL_INT(r); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_read_obj, + mod_trezorio_HID_read); + /// def write_blocking(self, msg: bytes, timeout_ms: int) -> int: /// """ /// Sends message using USB HID (device) or UDP (emulator). @@ -162,6 +176,7 @@ STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = { {MP_ROM_QSTR(MP_QSTR_iface_num), MP_ROM_PTR(&mod_trezorio_HID_iface_num_obj)}, {MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_HID_write_obj)}, + {MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_HID_read_obj)}, {MP_ROM_QSTR(MP_QSTR_write_blocking), MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)}, }; diff --git a/core/embed/upymod/modtrezorio/modtrezorio-poll.h b/core/embed/upymod/modtrezorio/modtrezorio-poll.h index 2d97e491e7f..3e0a738abdb 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-poll.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-poll.h @@ -166,22 +166,11 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref, } #endif else if (mode == POLL_READ) { - if (sectrue == usb_hid_can_read(iface)) { - uint8_t buf[64] = {0}; - int len = usb_hid_read(iface, buf, sizeof(buf)); - if (len > 0) { - ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); - ret->items[1] = mp_obj_new_bytes(buf, len); - return mp_const_true; - } - } else if (sectrue == usb_webusb_can_read(iface)) { - uint8_t buf[64] = {0}; - int len = usb_webusb_read(iface, buf, sizeof(buf)); - if (len > 0) { - ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); - ret->items[1] = mp_obj_new_bytes(buf, len); - return mp_const_true; - } + if ((sectrue == usb_hid_can_read(iface)) || + (sectrue == usb_webusb_can_read(iface))) { + ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); + ret->items[1] = MP_OBJ_NEW_SMALL_INT(64); + return mp_const_true; } } else if (mode == POLL_WRITE) { if (sectrue == usb_hid_can_write(iface)) { diff --git a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h index d893b107172..4773be07641 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h @@ -127,10 +127,25 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj, mod_trezorio_WebUSB_write); +/// def read(self, buf: bytes) -> int: +/// """ +/// Reads message using WebUSB (device) or UDP (emulator). +/// """ +STATIC mp_obj_t mod_trezorio_WebUSB_read(mp_obj_t self, mp_obj_t buffer) { + mp_obj_HID_t *o = MP_OBJ_TO_PTR(self); + mp_buffer_info_t buf = {0}; + mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE); + ssize_t r = usb_webusb_read(o->info.iface_num, buf.buf, buf.len); + return MP_OBJ_NEW_SMALL_INT(r); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_read_obj, + mod_trezorio_WebUSB_read); + STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = { {MP_ROM_QSTR(MP_QSTR_iface_num), MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)}, {MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)}, + {MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_WebUSB_read_obj)}, }; STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict, mod_trezorio_WebUSB_locals_dict_table); diff --git a/core/mocks/generated/trezorio/__init__.pyi b/core/mocks/generated/trezorio/__init__.pyi index efb11e08e96..8f34d9d1f01 100644 --- a/core/mocks/generated/trezorio/__init__.pyi +++ b/core/mocks/generated/trezorio/__init__.pyi @@ -32,6 +32,11 @@ class HID: Sends message using USB HID (device) or UDP (emulator). """ + def read(self, buf: bytes) -> int: + """ + Reads message using USB HID (device) or UDP (emulator). + """ + def write_blocking(self, msg: bytes, timeout_ms: int) -> int: """ Sends message using USB HID (device) or UDP (emulator). @@ -148,6 +153,11 @@ class WebUSB: """ Sends message using USB WebUSB (device) or UDP (emulator). """ + + def read(self, buf: bytes) -> int: + """ + Reads message using WebUSB (device) or UDP (emulator). + """ from . import fatfs, haptic, sdcard POLL_READ: int # wait until interface is readable and return read data POLL_WRITE: int # wait until interface is writable diff --git a/core/src/apps/webauthn/fido2.py b/core/src/apps/webauthn/fido2.py index 5a1bedd4ae7..2125ec86997 100644 --- a/core/src/apps/webauthn/fido2.py +++ b/core/src/apps/webauthn/fido2.py @@ -376,7 +376,11 @@ async def _read_cmd(iface: HID) -> Cmd | None: read = loop.wait(iface.iface_num() | io.POLL_READ) # wait for incoming command indefinitely - buf = await read + msg_len = await read + buf = bytearray(msg_len) + read_len = iface.read(buf) + if read_len != msg_len: + raise ValueError("Invalid length") while True: ifrm = overlay_struct(bytearray(buf), desc_init) bcnt = ifrm.bcnt @@ -415,7 +419,11 @@ async def _read_cmd(iface: HID) -> Cmd | None: read.timeout_ms = _CTAP_HID_TIMEOUT_MS while datalen < bcnt: try: - buf = await read + msg_len = await read + buf = bytearray(msg_len) + read_len = iface.read(buf) + if read_len != msg_len: + raise ValueError("Invalid length") except loop.Timeout: if __debug__: warning(__name__, "_ERR_MSG_TIMEOUT") diff --git a/core/src/trezor/wire/codec/codec_v1.py b/core/src/trezor/wire/codec/codec_v1.py index 02ff37f0eaf..68a2797ff56 100644 --- a/core/src/trezor/wire/codec/codec_v1.py +++ b/core/src/trezor/wire/codec/codec_v1.py @@ -25,7 +25,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag read = loop.wait(iface.iface_num() | io.POLL_READ) # wait for initial report - report = await read + msg_len = await read + report = bytearray(msg_len) + read_len = iface.read(report) + if read_len != msg_len: + raise CodecError("Invalid length") if report[0] != _REP_MARKER: raise CodecError("Invalid magic") _, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report) @@ -50,7 +54,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag while nread < msize: # wait for continuation report - report = await read + msg_len = await read + report = bytearray(msg_len) + read_len = iface.read(report) + if read_len != msg_len: + raise CodecError("Invalid length") if report[0] != _REP_MARKER: raise CodecError("Invalid magic") diff --git a/core/tests/test_trezor.wire.codec.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py index 78675859e2c..5a73467e258 100644 --- a/core/tests/test_trezor.wire.codec.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -12,6 +12,7 @@ class MockHID: def __init__(self, num): self.num = num self.data = [] + self.packet = None def iface_num(self): return self.num @@ -20,6 +21,17 @@ def write(self, msg): self.data.append(bytearray(msg)) return len(msg) + def mock_read(self, packet, gen): + self.packet = packet + return gen.send(len(packet)) + + def read(self, buffer): + if self.packet is None: + raise Exception("No packet to read") + buffer[:] = self.packet + self.packet = None + return len(buffer) + def wait_object(self, mode): return wait(mode | self.num) @@ -48,7 +60,7 @@ def test_read_one_packet(self): self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) with self.assertRaises(StopIteration) as e: - gen.send(message_packet) + self.interface.mock_read(message_packet, gen) # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value @@ -74,11 +86,11 @@ def test_read_many_packets(self): query = gen.send(None) for packet in packets[:-1]: self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) - query = gen.send(packet) + query = self.interface.mock_read(packet, gen) # last packet will stop with self.assertRaises(StopIteration) as e: - gen.send(packets[-1]) + self.interface.mock_read(packets[-1], gen) # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value @@ -103,7 +115,7 @@ def test_read_large_message(self): query = gen.send(None) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) with self.assertRaises(StopIteration) as e: - gen.send(packet) + self.interface.mock_read(packet, gen) # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value @@ -169,10 +181,10 @@ def test_roundtrip(self): query = gen.send(None) for packet in self.interface.data[:-1]: self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) - query = gen.send(packet) + query = self.interface.mock_read(packet, gen) with self.assertRaises(StopIteration) as e: - gen.send(self.interface.data[-1]) + self.interface.mock_read(self.interface.data[-1], gen) result = e.value.value self.assertEqual(result.type, MESSAGE_TYPE) @@ -194,10 +206,10 @@ def test_read_huge_packet(self): query = gen.send(None) for _ in range(PACKET_COUNT - 1): self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) - query = gen.send(packet) + query = self.interface.mock_read(packet, gen) with self.assertRaises(codec_v1.CodecError) as e: - gen.send(packet) + self.interface.mock_read(packet,gen) self.assertEqual(e.value.args[0], "Message too large") From 51556ef6ba7632ecda7feb7e9e816dea1ce02d09 Mon Sep 17 00:00:00 2001 From: tychovrahe Date: Wed, 4 Dec 2024 10:21:06 +0100 Subject: [PATCH 2/5] fixup! refactor(core): split polling can_read and reading from USB --- core/embed/io/usb/unix/usb.c | 66 ++++++++++++++----- .../upymod/modtrezorio/modtrezorio-hid.h | 28 +++++--- .../upymod/modtrezorio/modtrezorio-webusb.h | 26 ++++++-- core/mocks/generated/trezorio/__init__.pyi | 6 +- core/src/apps/webauthn/fido2.py | 4 +- core/src/trezor/wire/codec/codec_v1.py | 5 +- 6 files changed, 97 insertions(+), 38 deletions(-) diff --git a/core/embed/io/usb/unix/usb.c b/core/embed/io/usb/unix/usb.c index 8ab0d1087bc..e58f513904d 100644 --- a/core/embed/io/usb/unix/usb.c +++ b/core/embed/io/usb/unix/usb.c @@ -50,6 +50,8 @@ static struct { int sock; struct sockaddr_in si_me, si_other; socklen_t slen; + uint8_t msg[64]; + int msg_len; } usb_ifaces[USBD_MAX_NUM_INTERFACES]; secbool usb_init(const usb_dev_info_t *dev_info) { @@ -60,7 +62,9 @@ secbool usb_init(const usb_dev_info_t *dev_info) { usb_ifaces[i].sock = -1; memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in)); memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in)); + memzero(&usb_ifaces[i].msg, sizeof(usb_ifaces[i].msg)); usb_ifaces[i].slen = 0; + usb_ifaces[i].msg_len = 0; } return sectrue; } @@ -136,36 +140,66 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) { return sectrue; } -static secbool usb_emulated_poll(uint8_t iface_num, short dir) { +static secbool usb_emulated_poll_read(uint8_t iface_num) { struct pollfd fds[] = { - {usb_ifaces[iface_num].sock, dir, 0}, + {usb_ifaces[iface_num].sock, POLLIN, 0}, }; - int r = poll(fds, 1, 0); - return sectrue * (r > 0); -} + int res = poll(fds, 1, 0); + + if (res <= 0) { + return secfalse; + } -static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { struct sockaddr_in si; socklen_t sl = sizeof(si); - ssize_t r = recvfrom(usb_ifaces[iface_num].sock, buf, len, MSG_DONTWAIT, + ssize_t r = recvfrom(usb_ifaces[iface_num].sock, usb_ifaces[iface_num].msg, + sizeof(usb_ifaces[iface_num].msg), MSG_DONTWAIT, (struct sockaddr *)&si, &sl); - if (r < 0) { - return r; + if (r <= 0) { + return secfalse; } + usb_ifaces[iface_num].si_other = si; usb_ifaces[iface_num].slen = sl; static const char *ping_req = "PINGPING"; static const char *ping_resp = "PONGPONG"; - if (r == strlen(ping_req) && 0 == memcmp(ping_req, buf, strlen(ping_req))) { + if (r == strlen(ping_req) && + 0 == memcmp(ping_req, usb_ifaces[iface_num].msg, strlen(ping_req))) { if (usb_ifaces[iface_num].slen > 0) { sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp), MSG_DONTWAIT, (const struct sockaddr *)&usb_ifaces[iface_num].si_other, usb_ifaces[iface_num].slen); } - return 0; + memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg)); + return secfalse; } - return r; + + usb_ifaces[iface_num].msg_len = r; + + return sectrue; +} + +static secbool usb_emulated_poll_write(uint8_t iface_num) { + struct pollfd fds[] = { + {usb_ifaces[iface_num].sock, POLLOUT, 0}, + }; + int r = poll(fds, 1, 0); + return sectrue * (r > 0); +} + +static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { + if (usb_ifaces[iface_num].msg_len > 0) { + if (usb_ifaces[iface_num].msg_len < len) { + len = usb_ifaces[iface_num].msg_len; + } + memcpy(buf, usb_ifaces[iface_num].msg, len); + usb_ifaces[iface_num].msg_len = 0; + memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg)); + return len; + } + + return 0; } static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf, @@ -184,7 +218,7 @@ secbool usb_hid_can_read(uint8_t iface_num) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { return secfalse; } - return usb_emulated_poll(iface_num, POLLIN); + return usb_emulated_poll_read(iface_num); } secbool usb_webusb_can_read(uint8_t iface_num) { @@ -192,7 +226,7 @@ secbool usb_webusb_can_read(uint8_t iface_num) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { return secfalse; } - return usb_emulated_poll(iface_num, POLLIN); + return usb_emulated_poll_read(iface_num); } secbool usb_hid_can_write(uint8_t iface_num) { @@ -200,7 +234,7 @@ secbool usb_hid_can_write(uint8_t iface_num) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) { return secfalse; } - return usb_emulated_poll(iface_num, POLLOUT); + return usb_emulated_poll_write(iface_num); } secbool usb_webusb_can_write(uint8_t iface_num) { @@ -208,7 +242,7 @@ secbool usb_webusb_can_write(uint8_t iface_num) { usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) { return secfalse; } - return usb_emulated_poll(iface_num, POLLOUT); + return usb_emulated_poll_write(iface_num); } int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) { diff --git a/core/embed/upymod/modtrezorio/modtrezorio-hid.h b/core/embed/upymod/modtrezorio/modtrezorio-hid.h index cedc2b437d2..9b6e7c0a375 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-hid.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-hid.h @@ -141,19 +141,31 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj, mod_trezorio_HID_write); -/// def read(self, buf: bytes) -> int: +/// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int /// """ -/// Reads message using USB HID (device) or UDP (emulator). +/// Reads message using HID (device) or UDP (emulator). /// """ -STATIC mp_obj_t mod_trezorio_HID_read(mp_obj_t self, mp_obj_t buffer) { - mp_obj_HID_t *o = MP_OBJ_TO_PTR(self); +STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) { + mp_obj_HID_t *o = MP_OBJ_TO_PTR(args[0]); mp_buffer_info_t buf = {0}; - mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE); - ssize_t r = usb_hid_read(o->info.iface_num, buf.buf, buf.len); + mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE); + + int offset = mp_obj_get_int(args[2]); + + int len = buf.len - offset; + if (n_args >= 3) { + int limit = mp_obj_get_int(args[3]); + if ((limit - offset) < len) { + len = (limit - offset); + } + } + + ssize_t r = + usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], len); return MP_OBJ_NEW_SMALL_INT(r); } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_read_obj, - mod_trezorio_HID_read); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 3, 4, + mod_trezorio_HID_read); /// def write_blocking(self, msg: bytes, timeout_ms: int) -> int: /// """ diff --git a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h index 4773be07641..bb3c639f14a 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h @@ -127,19 +127,31 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj, mod_trezorio_WebUSB_write); -/// def read(self, buf: bytes) -> int: +/// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int /// """ /// Reads message using WebUSB (device) or UDP (emulator). /// """ -STATIC mp_obj_t mod_trezorio_WebUSB_read(mp_obj_t self, mp_obj_t buffer) { - mp_obj_HID_t *o = MP_OBJ_TO_PTR(self); +STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) { + mp_obj_WebUSB_t *o = MP_OBJ_TO_PTR(args[0]); mp_buffer_info_t buf = {0}; - mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE); - ssize_t r = usb_webusb_read(o->info.iface_num, buf.buf, buf.len); + mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE); + + int offset = mp_obj_get_int(args[2]); + + int len = buf.len - offset; + if (n_args >= 3) { + int limit = mp_obj_get_int(args[3]); + if ((limit - offset) < len) { + len = (limit - offset); + } + } + + ssize_t r = + usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], len); return MP_OBJ_NEW_SMALL_INT(r); } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_read_obj, - mod_trezorio_WebUSB_read); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 3, 4, + mod_trezorio_WebUSB_read); STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = { {MP_ROM_QSTR(MP_QSTR_iface_num), diff --git a/core/mocks/generated/trezorio/__init__.pyi b/core/mocks/generated/trezorio/__init__.pyi index 8f34d9d1f01..c828c552de0 100644 --- a/core/mocks/generated/trezorio/__init__.pyi +++ b/core/mocks/generated/trezorio/__init__.pyi @@ -32,9 +32,9 @@ class HID: Sends message using USB HID (device) or UDP (emulator). """ - def read(self, buf: bytes) -> int: + def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int """ - Reads message using USB HID (device) or UDP (emulator). + Reads message using HID (device) or UDP (emulator). """ def write_blocking(self, msg: bytes, timeout_ms: int) -> int: @@ -154,7 +154,7 @@ class WebUSB: Sends message using USB WebUSB (device) or UDP (emulator). """ - def read(self, buf: bytes) -> int: + def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int """ Reads message using WebUSB (device) or UDP (emulator). """ diff --git a/core/src/apps/webauthn/fido2.py b/core/src/apps/webauthn/fido2.py index 2125ec86997..5ee830ad7d4 100644 --- a/core/src/apps/webauthn/fido2.py +++ b/core/src/apps/webauthn/fido2.py @@ -378,7 +378,7 @@ async def _read_cmd(iface: HID) -> Cmd | None: # wait for incoming command indefinitely msg_len = await read buf = bytearray(msg_len) - read_len = iface.read(buf) + read_len = iface.read(buf, 0, msg_len) if read_len != msg_len: raise ValueError("Invalid length") while True: @@ -421,7 +421,7 @@ async def _read_cmd(iface: HID) -> Cmd | None: try: msg_len = await read buf = bytearray(msg_len) - read_len = iface.read(buf) + read_len = iface.read(buf, 0, msg_len) if read_len != msg_len: raise ValueError("Invalid length") except loop.Timeout: diff --git a/core/src/trezor/wire/codec/codec_v1.py b/core/src/trezor/wire/codec/codec_v1.py index 68a2797ff56..7134222b7c5 100644 --- a/core/src/trezor/wire/codec/codec_v1.py +++ b/core/src/trezor/wire/codec/codec_v1.py @@ -27,8 +27,9 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag # wait for initial report msg_len = await read report = bytearray(msg_len) - read_len = iface.read(report) + read_len = iface.read(report, 0, msg_len) if read_len != msg_len: + print("read_len", read_len, "msg_len", msg_len) raise CodecError("Invalid length") if report[0] != _REP_MARKER: raise CodecError("Invalid magic") @@ -56,7 +57,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag # wait for continuation report msg_len = await read report = bytearray(msg_len) - read_len = iface.read(report) + read_len = iface.read(report, 0, msg_len) if read_len != msg_len: raise CodecError("Invalid length") if report[0] != _REP_MARKER: From 34559a47716ad0eb41603d66ac4ebe9ed1e6841c Mon Sep 17 00:00:00 2001 From: tychovrahe Date: Wed, 4 Dec 2024 13:20:25 +0100 Subject: [PATCH 3/5] fixup! refactor(core): split polling can_read and reading from USB --- .../upymod/modtrezorio/modtrezorio-hid.h | 16 +++++++++------- .../upymod/modtrezorio/modtrezorio-webusb.h | 18 ++++++++++-------- core/tests/test_trezor.wire.codec.codec_v1.py | 19 +++++++++++++++---- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/core/embed/upymod/modtrezorio/modtrezorio-hid.h b/core/embed/upymod/modtrezorio/modtrezorio-hid.h index 9b6e7c0a375..ee92df81c1d 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-hid.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-hid.h @@ -150,18 +150,20 @@ STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) { mp_buffer_info_t buf = {0}; mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE); - int offset = mp_obj_get_int(args[2]); + int offset = 0; + if (n_args >= 2) { + offset = mp_obj_get_int(args[2]); + } - int len = buf.len - offset; + int limit; if (n_args >= 3) { - int limit = mp_obj_get_int(args[3]); - if ((limit - offset) < len) { - len = (limit - offset); - } + limit = mp_obj_get_int(args[3]); + } else { + limit = buf.len - offset; } ssize_t r = - usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], len); + usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], limit); return MP_OBJ_NEW_SMALL_INT(r); } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 3, 4, diff --git a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h index bb3c639f14a..20b69c99f88 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h @@ -136,21 +136,23 @@ STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) { mp_buffer_info_t buf = {0}; mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE); - int offset = mp_obj_get_int(args[2]); + int offset = 0; + if (n_args >= 2) { + offset = mp_obj_get_int(args[2]); + } - int len = buf.len - offset; + int limit; if (n_args >= 3) { - int limit = mp_obj_get_int(args[3]); - if ((limit - offset) < len) { - len = (limit - offset); - } + limit = mp_obj_get_int(args[3]); + } else { + limit = buf.len - offset; } ssize_t r = - usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], len); + usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], limit); return MP_OBJ_NEW_SMALL_INT(r); } -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 3, 4, +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 4, mod_trezorio_WebUSB_read); STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = { diff --git a/core/tests/test_trezor.wire.codec.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py index 5a73467e258..b4551077aa7 100644 --- a/core/tests/test_trezor.wire.codec.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -25,12 +25,23 @@ def mock_read(self, packet, gen): self.packet = packet return gen.send(len(packet)) - def read(self, buffer): + def read(self, buffer, offset=0, limit=None): if self.packet is None: raise Exception("No packet to read") - buffer[:] = self.packet - self.packet = None - return len(buffer) + if limit is None: + limit = len(buffer) - offset + + if len(self.packet) > limit: + end = offset + limit + buffer[offset:end] = self.packet[:limit] + self.packet = None + return limit + else: + end = offset + len(self.packet) + buffer[offset:end] = self.packet + read = len(self.packet) + self.packet = None + return read def wait_object(self, mode): return wait(mode | self.num) From c9188fbcd96c366043c49462840762cd41226d93 Mon Sep 17 00:00:00 2001 From: tychovrahe Date: Wed, 4 Dec 2024 14:26:29 +0100 Subject: [PATCH 4/5] fixup! refactor(core): split polling can_read and reading from USB --- core/embed/upymod/modtrezorio/modtrezorio-hid.h | 2 +- core/embed/upymod/modtrezorio/modtrezorio-webusb.h | 2 +- core/mocks/generated/trezorio/__init__.pyi | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/embed/upymod/modtrezorio/modtrezorio-hid.h b/core/embed/upymod/modtrezorio/modtrezorio-hid.h index ee92df81c1d..6f1c1bd2cd5 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-hid.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-hid.h @@ -143,7 +143,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj, /// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int /// """ -/// Reads message using HID (device) or UDP (emulator). +/// Reads message using USB HID (device) or UDP (emulator). /// """ STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) { mp_obj_HID_t *o = MP_OBJ_TO_PTR(args[0]); diff --git a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h index 20b69c99f88..a50e03dded0 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h @@ -129,7 +129,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj, /// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int /// """ -/// Reads message using WebUSB (device) or UDP (emulator). +/// Reads message using USB WebUSB (device) or UDP (emulator). /// """ STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) { mp_obj_WebUSB_t *o = MP_OBJ_TO_PTR(args[0]); diff --git a/core/mocks/generated/trezorio/__init__.pyi b/core/mocks/generated/trezorio/__init__.pyi index c828c552de0..463d7f12359 100644 --- a/core/mocks/generated/trezorio/__init__.pyi +++ b/core/mocks/generated/trezorio/__init__.pyi @@ -34,7 +34,7 @@ class HID: def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int """ - Reads message using HID (device) or UDP (emulator). + Reads message using USB HID (device) or UDP (emulator). """ def write_blocking(self, msg: bytes, timeout_ms: int) -> int: @@ -156,7 +156,7 @@ class WebUSB: def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int """ - Reads message using WebUSB (device) or UDP (emulator). + Reads message using USB WebUSB (device) or UDP (emulator). """ from . import fatfs, haptic, sdcard POLL_READ: int # wait until interface is readable and return read data From 7d9a53c06946b51d4738f5e0c1b8d971f8807b8c Mon Sep 17 00:00:00 2001 From: tychovrahe Date: Thu, 5 Dec 2024 14:08:07 +0100 Subject: [PATCH 5/5] fixup! refactor(core): split polling can_read and reading from USB wip --- core/embed/io/usb/inc/io/usb.h | 2 ++ .../upymod/modtrezorio/modtrezorio-hid.h | 24 +++++++++++------ .../upymod/modtrezorio/modtrezorio-poll.h | 9 +++---- .../upymod/modtrezorio/modtrezorio-webusb.h | 26 ++++++++++++------- core/mocks/generated/trezorio/__init__.pyi | 4 +-- core/src/apps/webauthn/fido2.py | 8 ++---- core/src/trezor/wire/codec/codec_v1.py | 9 ++----- core/tests/test_trezor.wire.codec.codec_v1.py | 13 ++++------ 8 files changed, 49 insertions(+), 46 deletions(-) diff --git a/core/embed/io/usb/inc/io/usb.h b/core/embed/io/usb/inc/io/usb.h index a039719dc43..2646ce97b9c 100644 --- a/core/embed/io/usb/inc/io/usb.h +++ b/core/embed/io/usb/inc/io/usb.h @@ -26,6 +26,8 @@ #include #include +#define USB_PACKET_LEN 64 + // clang-format off // // USB stack high-level state machine diff --git a/core/embed/upymod/modtrezorio/modtrezorio-hid.h b/core/embed/upymod/modtrezorio/modtrezorio-hid.h index 6f1c1bd2cd5..928914b33cf 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-hid.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-hid.h @@ -141,7 +141,7 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj, mod_trezorio_HID_write); -/// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int +/// def read(self, buf: bytes, offset: int = 0) -> int /// """ /// Reads message using USB HID (device) or UDP (emulator). /// """ @@ -155,15 +155,23 @@ STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) { offset = mp_obj_get_int(args[2]); } - int limit; - if (n_args >= 3) { - limit = mp_obj_get_int(args[3]); - } else { - limit = buf.len - offset; + if (offset < 0) { + mp_raise_ValueError("Negative offset not allowed"); + } + + uint32_t buffer_space = buf.len - offset; + + if (buffer_space < USB_PACKET_LEN) { + mp_raise_ValueError("Buffer too small"); + } + + ssize_t r = usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], + USB_PACKET_LEN); + + if (r != USB_PACKET_LEN) { + mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length"); } - ssize_t r = - usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], limit); return MP_OBJ_NEW_SMALL_INT(r); } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 3, 4, diff --git a/core/embed/upymod/modtrezorio/modtrezorio-poll.h b/core/embed/upymod/modtrezorio/modtrezorio-poll.h index 3e0a738abdb..467429e3f25 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-poll.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-poll.h @@ -169,15 +169,12 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref, if ((sectrue == usb_hid_can_read(iface)) || (sectrue == usb_webusb_can_read(iface))) { ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); - ret->items[1] = MP_OBJ_NEW_SMALL_INT(64); + ret->items[1] = MP_OBJ_NEW_SMALL_INT(USB_PACKET_LEN); return mp_const_true; } } else if (mode == POLL_WRITE) { - if (sectrue == usb_hid_can_write(iface)) { - ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); - ret->items[1] = mp_const_none; - return mp_const_true; - } else if (sectrue == usb_webusb_can_write(iface)) { + if ((sectrue == usb_hid_can_write(iface)) || + (sectrue == usb_webusb_can_write(iface))) { ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); ret->items[1] = mp_const_none; return mp_const_true; diff --git a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h index a50e03dded0..301d7b279e6 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h @@ -127,7 +127,7 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj, mod_trezorio_WebUSB_write); -/// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int +/// def read(self, buf: bytes, offset: int = 0) -> int /// """ /// Reads message using USB WebUSB (device) or UDP (emulator). /// """ @@ -141,18 +141,26 @@ STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) { offset = mp_obj_get_int(args[2]); } - int limit; - if (n_args >= 3) { - limit = mp_obj_get_int(args[3]); - } else { - limit = buf.len - offset; + if (offset < 0) { + mp_raise_ValueError("Negative offset not allowed"); + } + + uint32_t buffer_space = buf.len - offset; + + if (buffer_space < USB_PACKET_LEN) { + mp_raise_ValueError("Buffer too small"); + } + + ssize_t r = usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], + USB_PACKET_LEN); + + if (r != USB_PACKET_LEN) { + mp_raise_msg(&mp_type_RuntimeError, "Unexpected read length"); } - ssize_t r = - usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], limit); return MP_OBJ_NEW_SMALL_INT(r); } -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 4, +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 3, mod_trezorio_WebUSB_read); STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = { diff --git a/core/mocks/generated/trezorio/__init__.pyi b/core/mocks/generated/trezorio/__init__.pyi index 463d7f12359..4d8d3cda797 100644 --- a/core/mocks/generated/trezorio/__init__.pyi +++ b/core/mocks/generated/trezorio/__init__.pyi @@ -32,7 +32,7 @@ class HID: Sends message using USB HID (device) or UDP (emulator). """ - def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int + def read(self, buf: bytes, offset: int = 0) -> int """ Reads message using USB HID (device) or UDP (emulator). """ @@ -154,7 +154,7 @@ class WebUSB: Sends message using USB WebUSB (device) or UDP (emulator). """ - def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int + def read(self, buf: bytes, offset: int = 0) -> int """ Reads message using USB WebUSB (device) or UDP (emulator). """ diff --git a/core/src/apps/webauthn/fido2.py b/core/src/apps/webauthn/fido2.py index 5ee830ad7d4..b55f56281cb 100644 --- a/core/src/apps/webauthn/fido2.py +++ b/core/src/apps/webauthn/fido2.py @@ -378,9 +378,7 @@ async def _read_cmd(iface: HID) -> Cmd | None: # wait for incoming command indefinitely msg_len = await read buf = bytearray(msg_len) - read_len = iface.read(buf, 0, msg_len) - if read_len != msg_len: - raise ValueError("Invalid length") + iface.read(buf, 0) while True: ifrm = overlay_struct(bytearray(buf), desc_init) bcnt = ifrm.bcnt @@ -421,9 +419,7 @@ async def _read_cmd(iface: HID) -> Cmd | None: try: msg_len = await read buf = bytearray(msg_len) - read_len = iface.read(buf, 0, msg_len) - if read_len != msg_len: - raise ValueError("Invalid length") + iface.read(buf, 0) except loop.Timeout: if __debug__: warning(__name__, "_ERR_MSG_TIMEOUT") diff --git a/core/src/trezor/wire/codec/codec_v1.py b/core/src/trezor/wire/codec/codec_v1.py index 7134222b7c5..e39606be686 100644 --- a/core/src/trezor/wire/codec/codec_v1.py +++ b/core/src/trezor/wire/codec/codec_v1.py @@ -27,10 +27,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag # wait for initial report msg_len = await read report = bytearray(msg_len) - read_len = iface.read(report, 0, msg_len) - if read_len != msg_len: - print("read_len", read_len, "msg_len", msg_len) - raise CodecError("Invalid length") + iface.read(report, 0) if report[0] != _REP_MARKER: raise CodecError("Invalid magic") _, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report) @@ -57,9 +54,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag # wait for continuation report msg_len = await read report = bytearray(msg_len) - read_len = iface.read(report, 0, msg_len) - if read_len != msg_len: - raise CodecError("Invalid length") + iface.read(report, 0) if report[0] != _REP_MARKER: raise CodecError("Invalid magic") diff --git a/core/tests/test_trezor.wire.codec.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py index b4551077aa7..8fcf0add3c7 100644 --- a/core/tests/test_trezor.wire.codec.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -25,17 +25,14 @@ def mock_read(self, packet, gen): self.packet = packet return gen.send(len(packet)) - def read(self, buffer, offset=0, limit=None): + def read(self, buffer, offset=0): if self.packet is None: raise Exception("No packet to read") - if limit is None: - limit = len(buffer) - offset - if len(self.packet) > limit: - end = offset + limit - buffer[offset:end] = self.packet[:limit] - self.packet = None - return limit + buffer_space = len(buffer) - offset + + if len(self.packet) > buffer_space: + raise Exception("Buffer too small") else: end = offset + len(self.packet) buffer[offset:end] = self.packet