Skip to content

Commit

Permalink
refactor(core): split polling can_read and reading from USB
Browse files Browse the repository at this point in the history
[no changelog]
  • Loading branch information
TychoVrahe committed Dec 3, 2024
1 parent 13df961 commit e942f8e
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 28 deletions.
15 changes: 15 additions & 0 deletions core/embed/upymod/modtrezorio/modtrezorio-hid.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj,
mod_trezorio_HID_write);

/// def read(self, buf: bytes) -> int:
/// """
/// Reads message using USB HID (device) or UDP (emulator).
/// """
STATIC mp_obj_t mod_trezorio_HID_read(mp_obj_t self, mp_obj_t buffer) {
mp_obj_HID_t *o = MP_OBJ_TO_PTR(self);
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE);
ssize_t r = usb_hid_read(o->info.iface_num, buf.buf, buf.len);
return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_read_obj,
mod_trezorio_HID_read);

/// def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
/// """
/// Sends message using USB HID (device) or UDP (emulator).
Expand All @@ -162,6 +176,7 @@ STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_HID_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_HID_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_HID_read_obj)},
{MP_ROM_QSTR(MP_QSTR_write_blocking),
MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)},
};
Expand Down
21 changes: 5 additions & 16 deletions core/embed/upymod/modtrezorio/modtrezorio-poll.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,22 +166,11 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref,
}
#endif
else if (mode == POLL_READ) {
if (sectrue == usb_hid_can_read(iface)) {
uint8_t buf[64] = {0};
int len = usb_hid_read(iface, buf, sizeof(buf));
if (len > 0) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_obj_new_bytes(buf, len);
return mp_const_true;
}
} else if (sectrue == usb_webusb_can_read(iface)) {
uint8_t buf[64] = {0};
int len = usb_webusb_read(iface, buf, sizeof(buf));
if (len > 0) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_obj_new_bytes(buf, len);
return mp_const_true;
}
if ((sectrue == usb_hid_can_read(iface)) ||
(sectrue == usb_webusb_can_read(iface))) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = MP_OBJ_NEW_SMALL_INT(64);
return mp_const_true;
}
} else if (mode == POLL_WRITE) {
if (sectrue == usb_hid_can_write(iface)) {
Expand Down
15 changes: 15 additions & 0 deletions core/embed/upymod/modtrezorio/modtrezorio-webusb.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,25 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj,
mod_trezorio_WebUSB_write);

/// def read(self, buf: bytes) -> int:
/// """
/// Reads message using WebUSB (device) or UDP (emulator).
/// """
STATIC mp_obj_t mod_trezorio_WebUSB_read(mp_obj_t self, mp_obj_t buffer) {
mp_obj_HID_t *o = MP_OBJ_TO_PTR(self);
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE);
ssize_t r = usb_webusb_read(o->info.iface_num, buf.buf, buf.len);
return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_read_obj,
mod_trezorio_WebUSB_read);

STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_WebUSB_read_obj)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict,
mod_trezorio_WebUSB_locals_dict_table);
Expand Down
10 changes: 10 additions & 0 deletions core/mocks/generated/trezorio/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class HID:
Sends message using USB HID (device) or UDP (emulator).
"""

def read(self, buf: bytes) -> int:
"""
Reads message using USB HID (device) or UDP (emulator).
"""

def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
"""
Sends message using USB HID (device) or UDP (emulator).
Expand Down Expand Up @@ -148,6 +153,11 @@ class WebUSB:
"""
Sends message using USB WebUSB (device) or UDP (emulator).
"""

def read(self, buf: bytes) -> int:
"""
Reads message using WebUSB (device) or UDP (emulator).
"""
from . import fatfs, haptic, sdcard
POLL_READ: int # wait until interface is readable and return read data
POLL_WRITE: int # wait until interface is writable
Expand Down
12 changes: 10 additions & 2 deletions core/src/apps/webauthn/fido2.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,11 @@ async def _read_cmd(iface: HID) -> Cmd | None:
read = loop.wait(iface.iface_num() | io.POLL_READ)

# wait for incoming command indefinitely
buf = await read
msg_len = await read
buf = bytearray(msg_len)
read_len = iface.read(buf)
if read_len != msg_len:
raise ValueError("Invalid length")
while True:
ifrm = overlay_struct(bytearray(buf), desc_init)
bcnt = ifrm.bcnt
Expand Down Expand Up @@ -415,7 +419,11 @@ async def _read_cmd(iface: HID) -> Cmd | None:
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
while datalen < bcnt:
try:
buf = await read
msg_len = await read
buf = bytearray(msg_len)
read_len = iface.read(buf)
if read_len != msg_len:
raise ValueError("Invalid length")
except loop.Timeout:
if __debug__:
warning(__name__, "_ERR_MSG_TIMEOUT")
Expand Down
12 changes: 10 additions & 2 deletions core/src/trezor/wire/codec/codec_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
read = loop.wait(iface.iface_num() | io.POLL_READ)

# wait for initial report
report = await read
msg_len = await read
report = bytearray(msg_len)
read_len = iface.read(report)
if read_len != msg_len:
raise CodecError("Invalid length")
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
Expand All @@ -50,7 +54,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag

while nread < msize:
# wait for continuation report
report = await read
msg_len = await read
report = bytearray(msg_len)
read_len = iface.read(report)
if read_len != msg_len:
raise CodecError("Invalid length")
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")

Expand Down
28 changes: 20 additions & 8 deletions core/tests/test_trezor.wire.codec.codec_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class MockHID:
def __init__(self, num):
self.num = num
self.data = []
self.packet = None

def iface_num(self):
return self.num
Expand All @@ -20,6 +21,17 @@ def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)

def mock_read(self, packet, gen):
self.packet = packet
return gen.send(len(packet))

def read(self, buffer):
if self.packet is None:
raise Exception("No packet to read")
buffer[:] = self.packet
self.packet = None
return len(buffer)

def wait_object(self, mode):
return wait(mode | self.num)

Expand Down Expand Up @@ -48,7 +60,7 @@ def test_read_one_packet(self):
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))

with self.assertRaises(StopIteration) as e:
gen.send(message_packet)
self.interface.mock_read(message_packet, gen)

# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
Expand All @@ -74,11 +86,11 @@ def test_read_many_packets(self):
query = gen.send(None)
for packet in packets[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
query = self.interface.mock_read(packet, gen)

# last packet will stop
with self.assertRaises(StopIteration) as e:
gen.send(packets[-1])
self.interface.mock_read(packets[-1], gen)

# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
Expand All @@ -103,7 +115,7 @@ def test_read_large_message(self):
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
gen.send(packet)
self.interface.mock_read(packet, gen)

# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
Expand Down Expand Up @@ -169,10 +181,10 @@ def test_roundtrip(self):
query = gen.send(None)
for packet in self.interface.data[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
query = self.interface.mock_read(packet, gen)

with self.assertRaises(StopIteration) as e:
gen.send(self.interface.data[-1])
self.interface.mock_read(self.interface.data[-1], gen)

result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
Expand All @@ -194,10 +206,10 @@ def test_read_huge_packet(self):
query = gen.send(None)
for _ in range(PACKET_COUNT - 1):
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
query = self.interface.mock_read(packet, gen)

with self.assertRaises(codec_v1.CodecError) as e:
gen.send(packet)
self.interface.mock_read(packet,gen)

self.assertEqual(e.value.args[0], "Message too large")

Expand Down

0 comments on commit e942f8e

Please sign in to comment.