Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split polling can_read and reading from USB #4419

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 50 additions & 16 deletions core/embed/io/usb/unix/usb.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ static struct {
int sock;
struct sockaddr_in si_me, si_other;
socklen_t slen;
uint8_t msg[64];
int msg_len;
} usb_ifaces[USBD_MAX_NUM_INTERFACES];

secbool usb_init(const usb_dev_info_t *dev_info) {
Expand All @@ -60,7 +62,9 @@ secbool usb_init(const usb_dev_info_t *dev_info) {
usb_ifaces[i].sock = -1;
memzero(&usb_ifaces[i].si_me, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].si_other, sizeof(struct sockaddr_in));
memzero(&usb_ifaces[i].msg, sizeof(usb_ifaces[i].msg));
usb_ifaces[i].slen = 0;
usb_ifaces[i].msg_len = 0;
}
return sectrue;
}
Expand Down Expand Up @@ -136,36 +140,66 @@ secbool usb_vcp_add(const usb_vcp_info_t *info) {
return sectrue;
}

static secbool usb_emulated_poll(uint8_t iface_num, short dir) {
static secbool usb_emulated_poll_read(uint8_t iface_num) {
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, dir, 0},
{usb_ifaces[iface_num].sock, POLLIN, 0},
};
int r = poll(fds, 1, 0);
return sectrue * (r > 0);
}
int res = poll(fds, 1, 0);

if (res <= 0) {
return secfalse;
}

static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
struct sockaddr_in si;
socklen_t sl = sizeof(si);
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, buf, len, MSG_DONTWAIT,
ssize_t r = recvfrom(usb_ifaces[iface_num].sock, usb_ifaces[iface_num].msg,
sizeof(usb_ifaces[iface_num].msg), MSG_DONTWAIT,
(struct sockaddr *)&si, &sl);
if (r < 0) {
return r;
if (r <= 0) {
return secfalse;
}

usb_ifaces[iface_num].si_other = si;
usb_ifaces[iface_num].slen = sl;
static const char *ping_req = "PINGPING";
static const char *ping_resp = "PONGPONG";
if (r == strlen(ping_req) && 0 == memcmp(ping_req, buf, strlen(ping_req))) {
if (r == strlen(ping_req) &&
0 == memcmp(ping_req, usb_ifaces[iface_num].msg, strlen(ping_req))) {
if (usb_ifaces[iface_num].slen > 0) {
sendto(usb_ifaces[iface_num].sock, ping_resp, strlen(ping_resp),
MSG_DONTWAIT,
(const struct sockaddr *)&usb_ifaces[iface_num].si_other,
usb_ifaces[iface_num].slen);
}
return 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
return secfalse;
}
return r;

usb_ifaces[iface_num].msg_len = r;

return sectrue;
}

static secbool usb_emulated_poll_write(uint8_t iface_num) {
struct pollfd fds[] = {
{usb_ifaces[iface_num].sock, POLLOUT, 0},
};
int r = poll(fds, 1, 0);
return sectrue * (r > 0);
}

static int usb_emulated_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
if (usb_ifaces[iface_num].msg_len > 0) {
if (usb_ifaces[iface_num].msg_len < len) {
len = usb_ifaces[iface_num].msg_len;
}
memcpy(buf, usb_ifaces[iface_num].msg, len);
usb_ifaces[iface_num].msg_len = 0;
memzero(usb_ifaces[iface_num].msg, sizeof(usb_ifaces[iface_num].msg));
return len;
}

return 0;
}

static int usb_emulated_write(uint8_t iface_num, const uint8_t *buf,
Expand All @@ -184,31 +218,31 @@ secbool usb_hid_can_read(uint8_t iface_num) {
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLIN);
return usb_emulated_poll_read(iface_num);
}

secbool usb_webusb_can_read(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLIN);
return usb_emulated_poll_read(iface_num);
}

secbool usb_hid_can_write(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_HID) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLOUT);
return usb_emulated_poll_write(iface_num);
}

secbool usb_webusb_can_write(uint8_t iface_num) {
if (iface_num >= USBD_MAX_NUM_INTERFACES ||
usb_ifaces[iface_num].type != USB_IFACE_TYPE_WEBUSB) {
return secfalse;
}
return usb_emulated_poll(iface_num, POLLOUT);
return usb_emulated_poll_write(iface_num);
}

int usb_hid_read(uint8_t iface_num, uint8_t *buf, uint32_t len) {
Expand Down
29 changes: 29 additions & 0 deletions core/embed/upymod/modtrezorio/modtrezorio-hid.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,34 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj,
mod_trezorio_HID_write);

/// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int
/// """
/// Reads message using USB HID (device) or UDP (emulator).
/// """
STATIC mp_obj_t mod_trezorio_HID_read(size_t n_args, const mp_obj_t *args) {
mp_obj_HID_t *o = MP_OBJ_TO_PTR(args[0]);
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);

int offset = 0;
if (n_args >= 2) {
offset = mp_obj_get_int(args[2]);
}

int limit;
if (n_args >= 3) {
limit = mp_obj_get_int(args[3]);
} else {
limit = buf.len - offset;
}

ssize_t r =
usb_hid_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], limit);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also need to check that offset and limit are valid for the buf
(and ideally raise a ValueError they exceed)
dtto webusb

Copy link
Contributor Author

@TychoVrahe TychoVrahe Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

solved as part of 7d9a53c, see below

return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_HID_read_obj, 3, 4,
mod_trezorio_HID_read);

/// def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
/// """
/// Sends message using USB HID (device) or UDP (emulator).
Expand All @@ -162,6 +190,7 @@ STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_HID_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_HID_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_HID_read_obj)},
{MP_ROM_QSTR(MP_QSTR_write_blocking),
MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)},
};
Expand Down
21 changes: 5 additions & 16 deletions core/embed/upymod/modtrezorio/modtrezorio-poll.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,22 +166,11 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref,
}
#endif
else if (mode == POLL_READ) {
if (sectrue == usb_hid_can_read(iface)) {
uint8_t buf[64] = {0};
int len = usb_hid_read(iface, buf, sizeof(buf));
if (len > 0) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_obj_new_bytes(buf, len);
return mp_const_true;
}
} else if (sectrue == usb_webusb_can_read(iface)) {
uint8_t buf[64] = {0};
int len = usb_webusb_read(iface, buf, sizeof(buf));
if (len > 0) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = mp_obj_new_bytes(buf, len);
return mp_const_true;
}
if ((sectrue == usb_hid_can_read(iface)) ||
(sectrue == usb_webusb_can_read(iface))) {
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
ret->items[1] = MP_OBJ_NEW_SMALL_INT(64);
return mp_const_true;
}
} else if (mode == POLL_WRITE) {
if (sectrue == usb_hid_can_write(iface)) {
Expand Down
29 changes: 29 additions & 0 deletions core/embed/upymod/modtrezorio/modtrezorio-webusb.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,39 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj,
mod_trezorio_WebUSB_write);

/// def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int
/// """
/// Reads message using USB WebUSB (device) or UDP (emulator).
/// """
STATIC mp_obj_t mod_trezorio_WebUSB_read(size_t n_args, const mp_obj_t *args) {
mp_obj_WebUSB_t *o = MP_OBJ_TO_PTR(args[0]);
mp_buffer_info_t buf = {0};
mp_get_buffer_raise(args[1], &buf, MP_BUFFER_WRITE);

int offset = 0;
if (n_args >= 2) {
offset = mp_obj_get_int(args[2]);
}

int limit;
if (n_args >= 3) {
limit = mp_obj_get_int(args[3]);
} else {
limit = buf.len - offset;
}

ssize_t r =
usb_webusb_read(o->info.iface_num, &((uint8_t *)buf.buf)[offset], limit);
return MP_OBJ_NEW_SMALL_INT(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorio_WebUSB_read_obj, 2, 4,
mod_trezorio_WebUSB_read);

STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {
{MP_ROM_QSTR(MP_QSTR_iface_num),
MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)},
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)},
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_WebUSB_read_obj)},
};
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict,
mod_trezorio_WebUSB_locals_dict_table);
Expand Down
10 changes: 10 additions & 0 deletions core/mocks/generated/trezorio/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class HID:
Sends message using USB HID (device) or UDP (emulator).
"""

def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int
"""
Reads message using USB HID (device) or UDP (emulator).
"""

def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
"""
Sends message using USB HID (device) or UDP (emulator).
Expand Down Expand Up @@ -148,6 +153,11 @@ class WebUSB:
"""
Sends message using USB WebUSB (device) or UDP (emulator).
"""

def read(self, buf: bytes, offset: int = 0, limit: int | None = None) -> int
"""
Reads message using USB WebUSB (device) or UDP (emulator).
"""
from . import fatfs, haptic, sdcard
POLL_READ: int # wait until interface is readable and return read data
POLL_WRITE: int # wait until interface is writable
Expand Down
12 changes: 10 additions & 2 deletions core/src/apps/webauthn/fido2.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,11 @@ async def _read_cmd(iface: HID) -> Cmd | None:
read = loop.wait(iface.iface_num() | io.POLL_READ)

# wait for incoming command indefinitely
buf = await read
msg_len = await read
buf = bytearray(msg_len)
read_len = iface.read(buf, 0, msg_len)
if read_len != msg_len:
raise ValueError("Invalid length")
while True:
ifrm = overlay_struct(bytearray(buf), desc_init)
bcnt = ifrm.bcnt
Expand Down Expand Up @@ -415,7 +419,11 @@ async def _read_cmd(iface: HID) -> Cmd | None:
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
while datalen < bcnt:
try:
buf = await read
msg_len = await read
buf = bytearray(msg_len)
read_len = iface.read(buf, 0, msg_len)
if read_len != msg_len:
raise ValueError("Invalid length")
except loop.Timeout:
if __debug__:
warning(__name__, "_ERR_MSG_TIMEOUT")
Expand Down
13 changes: 11 additions & 2 deletions core/src/trezor/wire/codec/codec_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
read = loop.wait(iface.iface_num() | io.POLL_READ)

# wait for initial report
report = await read
msg_len = await read
report = bytearray(msg_len)
read_len = iface.read(report, 0, msg_len)
if read_len != msg_len:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these checks are kind of annoying... could we get rid of them somehow?
on usb we don't really need to know msg_len because it is fixed, so i'm guessing this is for BT with variable length packets?
perhaps read() should guarantee that it reads full limit (or buffer size) and exception out otherwise?
or, can we guarantee that if we got msg_len out of await read, the next read must return that many bytes? (i mean we could always do the full read into a buffer like we do on unix, i suppose...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about this: 7d9a53c

  • the guarantee that msg_len is equal to real read data is implemented
  • exception if buffer space is insufficient
  • exception if for any reason the underlying USB implementation reads message of different length, but modification of the buffer is allowed under such circumstances
  • limit is removed, assumed that you always read the whole message of the msg_len size. (or do we ever want to allow reading partials, even if it means losing the rest of the data?)

print("read_len", read_len, "msg_len", msg_len)
raise CodecError("Invalid length")
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
Expand All @@ -50,7 +55,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag

while nread < msize:
# wait for continuation report
report = await read
msg_len = await read
report = bytearray(msg_len)
read_len = iface.read(report, 0, msg_len)
if read_len != msg_len:
raise CodecError("Invalid length")
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")

Expand Down
Loading
Loading