Skip to content

Commit

Permalink
mission protocol: added lots of unit testing and protocol hardeing
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasDebrunner committed Jan 22, 2024
1 parent 8717acf commit 4f6a688
Show file tree
Hide file tree
Showing 3 changed files with 436 additions and 52 deletions.
64 changes: 42 additions & 22 deletions include/mav/opinionated_protocols/MissionClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ namespace mission {
std::shared_ptr<mav::Connection> _connection;
const MessageSet& _message_set;

void _assertNotNack(const Message& message) {
if (message.id() == _message_set.idForMessage("MISSION_ACK")) {
if (message["type"].as<uint64_t>() != _message_set.e("MAV_MISSION_ACCEPTED")) {
throw mav::ProtocolException("Received NACK from server. Mission transaction failed.");
}
}
}

public:
MissionClient(std::shared_ptr<Connection> connection, const MessageSet& message_set) :
_connection(std::move(connection)),
Expand All @@ -44,15 +52,17 @@ namespace mission {
int mission_type = mission_messages[0]["mission_type"];

// Send mission count
auto mission_count_message = _message_set.create("MISSION_COUNT").set({
auto mission_count_message = _message_set.create("MISSION_COUNT")({
{"target_system", target.system_id},
{"target_component", target.component_id},
{"count", mission_messages.size()},
{"mission_type", mission_type}});

auto count_response = exchangeRetry(_connection, mission_count_message, "MISSION_REQUEST_INT",
target.system_id, target.component_id, retry_count, item_timeout);

Message count_response = exchangeRetryAnyResponse(_connection, _message_set, mission_count_message,
{"MISSION_ACK", "MISSION_REQUEST_INT"},
target.system_id, target.component_id,
retry_count, item_timeout);
_assertNotNack(count_response);
throwAssert(count_response["mission_type"].as<int>() == mission_type, "Mission type mismatch");
throwAssert(count_response["seq"].as<int>() == 0, "Sequence number mismatch");

Expand All @@ -63,54 +73,65 @@ namespace mission {
mission_item_message["target_component"] = target.component_id;
mission_item_message["seq"] = seq;

// we expect an ack for the last message, otherwise a request for the next one
const auto expected_response = seq == static_cast<int>(mission_messages.size()) - 1 ?
"MISSION_ACK" : "MISSION_REQUEST_INT";
auto item_response = exchangeRetry(_connection, mission_item_message, expected_response,
auto item_response = exchangeRetryAnyResponse(_connection, _message_set, mission_item_message,
{"MISSION_ACK", "MISSION_REQUEST_INT"},
target.system_id, target.component_id, retry_count, item_timeout);
// NACK is always bad, throw if we receive NACK
_assertNotNack(item_response);

if (seq == static_cast<int>(mission_messages.size()) - 1) {
// we expect an ack for the last message
throwAssert(item_response["type"].as<int>() == _message_set.e("MAV_MISSION_ACCEPTED"), "Mission upload failed");
seq++;
} else {
// we expect a request for the next message
throwAssert(item_response["mission_type"].as<int>() == mission_type, "Mission type mismatch");
seq = item_response["seq"];
// we're okay with an ack, only when we're at the last message
if (item_response.id() == _message_set.idForMessage("MISSION_ACK")) {
break;
}
}
// in general, we need a mission request int
throwAssert(item_response.id() == _message_set.idForMessage("MISSION_REQUEST_INT"), "Unexpected message");
// we expect a request for the next message
throwAssert(item_response["mission_type"].as<int>() == mission_type, "Mission type mismatch");
int response_seq = item_response["seq"];
// we allow requests for the next message or the current message
throwAssert((response_seq == seq + 1 || response_seq == seq)
&& response_seq < static_cast<int>(mission_messages.size()), "Sequence number mismatch");
seq = response_seq;
}
}

std::vector<Message> download(Identifier target={1, 1}, int mission_type=0, int retry_count=3, int item_timeout=1000) {
// Send mission request list
auto mission_request_list_message = _message_set.create("MISSION_REQUEST_LIST").set({
auto mission_request_list_message = _message_set.create("MISSION_REQUEST_LIST")({
{"target_system", target.system_id},
{"target_component", target.component_id},
{"mission_type", mission_type}});

auto request_list_response = exchangeRetry(_connection, mission_request_list_message, "MISSION_COUNT",
auto request_list_response = exchangeRetryAnyResponse(_connection, _message_set,
mission_request_list_message, {"MISSION_COUNT", "MISSION_ACK"},
target.system_id, target.component_id, retry_count, item_timeout);

_assertNotNack(request_list_response);
throwAssert(request_list_response.id() == _message_set.idForMessage("MISSION_COUNT"), "Unexpected message");
throwAssert(request_list_response["mission_type"].as<int>() == mission_type, "Mission type mismatch");

int count = request_list_response["count"];
std::vector<Message> mission_messages;
for (int seq = 0; seq < count; seq++) {
auto mission_request_message = _message_set.create("MISSION_REQUEST_INT").set({
auto mission_request_message = _message_set.create("MISSION_REQUEST_INT")({
{"target_system", target.system_id},
{"target_component", target.component_id},
{"seq", seq},
{"mission_type", 0}});

auto request_response = exchangeRetry(_connection, mission_request_message, "MISSION_ITEM_INT",
auto request_response = exchangeRetryAnyResponse(_connection, _message_set, mission_request_message,
{"MISSION_ITEM_INT", "MISSION_ACK"},
target.system_id, target.component_id, retry_count, item_timeout);

_assertNotNack(request_response);
throwAssert(request_response.id() == _message_set.idForMessage("MISSION_ITEM_INT"), "Unexpected message");
throwAssert(request_response["mission_type"].as<int>() == 0, "Mission type mismatch");
throwAssert(request_response["seq"].as<int>() == seq, "Sequence number mismatch");

mission_messages.push_back(request_response);
}
auto ack_message = _message_set.create("MISSION_ACK").set({
auto ack_message = _message_set.create("MISSION_ACK")({
{"target_system", target.system_id},
{"target_component", target.component_id},
{"type", _message_set.e("MAV_MISSION_ACCEPTED")},
Expand All @@ -123,5 +144,4 @@ namespace mission {
};
}


#endif //MAV_MISSIONCLIENT_H
74 changes: 63 additions & 11 deletions include/mav/opinionated_protocols/ProtocolUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

namespace mav {

class ProtocolError : public std::runtime_error {
class ProtocolException : public std::runtime_error {
public:
ProtocolError(const std::string &message) : std::runtime_error(message) {}
ProtocolException(const std::string &message) : std::runtime_error(message) {}
};

inline void ensureMessageInMessageSet(const MessageSet &message_set, const std::initializer_list<std::string> &message_names) {
Expand All @@ -27,12 +27,30 @@ namespace mav {

inline void throwAssert(bool condition, const std::string& message) {
if (!condition) {
throw ProtocolError(message);
throw ProtocolException(message);
}
}

inline Connection::Expectation expectAny(const std::shared_ptr<mav::Connection>& connection,
const std::vector<int> &message_ids, int source_id=mav::ANY_ID, int component_id=mav::ANY_ID) {
return connection->expect([message_ids, source_id, component_id](const Message& message) {
return std::find(message_ids.begin(), message_ids.end(), message.id()) != message_ids.end() &&
(source_id == mav::ANY_ID || message.header().systemId() == source_id) &&
(component_id == mav::ANY_ID || message.header().componentId() == component_id);
});
}

inline Connection::Expectation expectAny(const std::shared_ptr<mav::Connection>& connection, const MessageSet &message_set,
const std::vector<std::string> &message_names, int source_id=mav::ANY_ID, int component_id=mav::ANY_ID) {
std::vector<int> message_ids;
for (const auto &message_name : message_names) {
message_ids.push_back(message_set.idForMessage(message_name));
}
return expectAny(connection, message_ids, source_id, component_id);
}

inline Message exchange(
std::shared_ptr<mav::Connection> connection,
const std::shared_ptr<mav::Connection> &connection,
Message &request,
const std::string &response_message_name,
int source_id = mav::ANY_ID,
Expand All @@ -43,25 +61,59 @@ namespace mav {
return connection->receive(expectation, timeout_ms);
}

inline Message exchangeRetry(
const std::shared_ptr<mav::Connection>& connection,
inline Message exchangeAnyResponse(
const std::shared_ptr<mav::Connection> &connection,
const MessageSet &message_set,
Message &request,
const std::string &response_message_name,
const std::vector<std::string> &response_message_names,
int source_id = mav::ANY_ID,
int source_component = mav::ANY_ID,
int retries = 3,
int timeout_ms = 1000) {
auto expectation = expectAny(connection, message_set, response_message_names, source_id, source_component);
connection->send(request);
return connection->receive(expectation, timeout_ms);
}

template <typename Ret, typename... Arg>
inline Ret _retry(int retries, Ret (*func)(Arg...), Arg... args) {
for (int i = 0; i < retries; i++) {
try {
return exchange(connection, request, response_message_name, source_id, source_component, timeout_ms);
return func(args...);
} catch (const TimeoutException& e) {
if (i == retries - 1) {
throw e;
}
}
}
throw ProtocolError("Exchange of message " + request.name() + " -> " + response_message_name +
" failed after " + std::to_string(retries) + " retries");
throw ProtocolException("Function failed after " + std::to_string(retries) + " retries");
}

inline Message exchangeRetry(
const std::shared_ptr<mav::Connection> &connection,
Message &request,
const std::string &response_message_name,
int source_id = mav::ANY_ID,
int source_component = mav::ANY_ID,
int retries = 3,
int timeout_ms = 1000) {
return _retry<Message, const std::shared_ptr<Connection>&, Message&, const std::string&, int, int, int>
(retries, exchange, connection, request, response_message_name, source_id,
source_component, timeout_ms);
}

inline Message exchangeRetryAnyResponse(
const std::shared_ptr<mav::Connection> &connection,
const MessageSet &message_set,
Message &request,
const std::vector<std::string> &response_message_names,
int source_id = mav::ANY_ID,
int source_component = mav::ANY_ID,
int retries = 3,
int timeout_ms = 1000) {
return _retry<Message, const std::shared_ptr<Connection>&, const MessageSet&, Message&, const std::vector<std::string>&
, int, int, int>
(retries, exchangeAnyResponse, connection, message_set, request, response_message_names, source_id,
source_component, timeout_ms);
}
}

Expand Down
Loading

0 comments on commit 4f6a688

Please sign in to comment.