Skip to content

Commit

Permalink
improve blocking receive
Browse files Browse the repository at this point in the history
  • Loading branch information
deanlee committed Aug 13, 2024
1 parent 23cb05a commit d36d04c
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 63 deletions.
5 changes: 3 additions & 2 deletions SConscript
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ msgq_objects = env.SharedObject([
'msgq/msgq.cc',
])
msgq = env.Library('msgq', msgq_objects)
msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", common])
msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", 'pthread',common])

# Build Vision IPC
vipc_files = ['visionipc.cc', 'visionipc_server.cc', 'visionipc_client.cc', 'visionbuf.cc']
Expand All @@ -31,7 +31,7 @@ visionipc = env.Library('visionipc', vipc_objects)


vipc_frameworks = []
vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq"]
vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq", 'pthread']
if arch == "Darwin":
vipc_frameworks.append('OpenCL')
else:
Expand All @@ -45,4 +45,5 @@ if GetOption('extras'):
[f'{visionipc_dir.abspath}/test_runner.cc', f'{visionipc_dir.abspath}/visionipc_tests.cc'],
LIBS=['pthread'] + vipc_libs, FRAMEWORKS=vipc_frameworks)

msgq = [msgq, 'pthread']
Export('visionipc', 'msgq', 'msgq_python')
109 changes: 50 additions & 59 deletions msgq/impl_msgq.cc
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
#include <cassert>
#include <cstring>
#include <iostream>
#include <cstdlib>
#include <chrono>
#include <csignal>
#include <cerrno>

#include "msgq/impl_msgq.h"


volatile sig_atomic_t msgq_do_exit = 0;

void sig_handler(int signal) {
assert(signal == SIGINT || signal == SIGTERM);
msgq_do_exit = 1;
}

using namespace std::chrono;

MSGQContext::MSGQContext() {
}
Expand Down Expand Up @@ -70,61 +62,60 @@ int MSGQSubSocket::connect(Context *context, std::string endpoint, std::string a
return 0;
}


Message * MSGQSubSocket::receive(bool non_blocking){
msgq_do_exit = 0;

void (*prev_handler_sigint)(int);
void (*prev_handler_sigterm)(int);
if (!non_blocking){
prev_handler_sigint = std::signal(SIGINT, sig_handler);
prev_handler_sigterm = std::signal(SIGTERM, sig_handler);
}

msgq_msg_t msg;

MSGQMessage *r = NULL;

Message *MSGQSubSocket::receive(bool non_blocking) {
msgq_msg_t msg{};
int rc = msgq_msg_recv(&msg, q);

// Hack to implement blocking read with a poller. Don't use this
while (!non_blocking && rc == 0 && msgq_do_exit == 0){
msgq_pollitem_t items[1];
items[0].q = q;

int t = (timeout != -1) ? timeout : 100;

int n = msgq_poll(items, 1, t);
rc = msgq_msg_recv(&msg, q);

// The poll indicated a message was ready, but the receive failed. Try again
if (n == 1 && rc == 0){
continue;
}

if (timeout != -1){
break;
if (rc == 0 && !non_blocking) {
sigset_t mask;
sigset_t old_mask;
sigemptyset(&mask);
sigaddset(&mask, SIGINT);
sigaddset(&mask, SIGTERM);
sigaddset(&mask, SIGUSR2);

pthread_sigmask(SIG_BLOCK, &mask, &old_mask);
// sigprocmask(SIG_BLOCK, &mask, nullptr);

// Set timeout, default is 100 ms
int64_t timieout_ns = ((timeout != -1) ? timeout : 2000) * 1000000;
auto start = steady_clock::now();

// Continue receiving messages until timeout or interruption by SIGINT or SIGTERM
while (rc == 0 && timieout_ns > 0) {
struct timespec ts {
timieout_ns / 1000000000,
timieout_ns % 1000000000,
};

int ret = sigtimedwait(&mask, nullptr, &ts);
printf("%d\n", ret);
if (ret == SIGINT || ret == SIGTERM) {
// Ensure signal handling is not missed
raise(ret);
break;
} else if (ret == -1 && errno == EAGAIN && timeout != -1) {
break; // Timed out
} else {
}

rc = msgq_msg_recv(&msg, q);

if (timeout != -1) {
timieout_ns -= duration_cast<nanoseconds>(steady_clock::now() - start).count();
start = steady_clock::now(); // Update start time
}
}
pthread_sigmask(SIG_SETMASK, &old_mask, nullptr);
// sigprocmask(SIG_UNBLOCK, &mask, nullptr);
}


if (!non_blocking){
std::signal(SIGINT, prev_handler_sigint);
std::signal(SIGTERM, prev_handler_sigterm);
}

errno = msgq_do_exit ? EINTR : 0;

if (rc > 0){
if (msgq_do_exit){
msgq_msg_close(&msg); // Free unused message on exit
} else {
r = new MSGQMessage;
r->takeOwnership(msg.data, msg.size);
}
if (rc > 0) {
MSGQMessage *r = new MSGQMessage;
r->takeOwnership(msg.data, msg.size);
return r;
}

return (Message*)r;
return nullptr;
}

void MSGQSubSocket::setTimeout(int t){
Expand Down
2 changes: 1 addition & 1 deletion msgq/ipc.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ cdef extern from "msgq/ipc.h":
@staticmethod
SubSocket * create()
int connect(Context *, string, string, bool)
Message * receive(bool)
Message * receive(bool) nogil
void setTimeout(int)

cdef cppclass PubSocket:
Expand Down
4 changes: 3 additions & 1 deletion msgq/ipc_pyx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ cdef class SubSocket:
self.socket.setTimeout(timeout)

def receive(self, bool non_blocking=False):
msg = self.socket.receive(non_blocking)
cdef cppMessage *msg
with nogil:
msg = self.socket.receive(non_blocking)

if msg == NULL:
# If a blocking read returns no message check errno if SIGINT was caught in the C++ code
Expand Down
25 changes: 25 additions & 0 deletions msgq/tests/test_messaging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import pytest
import random
import signal
import threading
import time
import string
import msgq
Expand Down Expand Up @@ -67,3 +70,25 @@ def test_receive_timeout(self):
recvd = sub_sock.receive()
assert (time.monotonic() - start_time) < 0.2
assert recvd is None

def test_receive_interrupts_on_sigint(self):
sock = random_sock()
sub_sock = msgq.sub_sock(sock)

pid = os.getpid()
# Send SIGINT after a short delay
def send_sigint():
time.sleep(.5)
os.kill(pid, signal.SIGINT)

# Start a thread to send SIGINT
thread = threading.Thread(target=send_sigint)
thread.start()

with pytest.raises(KeyboardInterrupt):
start_time = time.monotonic()
recvd = sub_sock.receive()
assert (time.monotonic() - start_time) < 0.5
assert recvd is None

thread.join()

0 comments on commit d36d04c

Please sign in to comment.