From 34b47c01e48ec5b2eaf59964a01c69f4d6476e26 Mon Sep 17 00:00:00 2001 From: Bert Melis Date: Wed, 29 May 2024 19:58:06 +0200 Subject: [PATCH] simplify concurrency guarding --- src/MqttClient.cpp | 92 +++++++++++++++++++++++++--------------------- src/MqttClient.h | 10 +++-- 2 files changed, 57 insertions(+), 45 deletions(-) diff --git a/src/MqttClient.cpp b/src/MqttClient.cpp index fdc2cb2..ce70f24 100644 --- a/src/MqttClient.cpp +++ b/src/MqttClient.cpp @@ -118,9 +118,10 @@ bool MqttClient::connect() { } #endif } else { - EMC_SEMAPHORE_GIVE(); emc_log_e("Could not create CONNECT packet"); + EMC_SEMAPHORE_GIVE(); _onError(0, Error::OUT_OF_MEMORY); + EMC_SEMAPHORE_TAKE(); } EMC_SEMAPHORE_GIVE(); } @@ -151,7 +152,9 @@ uint16_t MqttClient::publish(const char* topic, uint8_t qos, bool retain, const uint16_t packetId = (qos > 0) ? _getNextPacketId() : 1; if (!_addPacket(packetId, topic, payload, length, qos, retain)) { emc_log_e("Could not create PUBLISH packet"); + EMC_SEMAPHORE_GIVE(); _onError(packetId, Error::OUT_OF_MEMORY); + EMC_SEMAPHORE_TAKE(); packetId = 0; } EMC_SEMAPHORE_GIVE(); @@ -175,7 +178,9 @@ uint16_t MqttClient::publish(const char* topic, uint8_t qos, bool retain, espMqt uint16_t packetId = (qos > 0) ? _getNextPacketId() : 1; if (!_addPacket(packetId, topic, callback, length, qos, retain)) { emc_log_e("Could not create PUBLISH packet"); + EMC_SEMAPHORE_GIVE(); _onError(packetId, Error::OUT_OF_MEMORY); + EMC_SEMAPHORE_TAKE(); packetId = 0; } EMC_SEMAPHORE_GIVE(); @@ -183,7 +188,9 @@ uint16_t MqttClient::publish(const char* topic, uint8_t qos, bool retain, espMqt } void MqttClient::clearQueue(bool deleteSessionData) { + EMC_SEMAPHORE_TAKE(); _clearQueue(deleteSessionData ? 2 : 0); + EMC_SEMAPHORE_GIVE(); } const char* MqttClient::getClientId() const { @@ -227,9 +234,11 @@ void MqttClient::loop() { case State::connectingMqtt: #if EMC_WAIT_FOR_CONNACK if (_transport->connected()) { + EMC_SEMAPHORE_TAKE(); _sendPacket(); _checkIncoming(); _checkPing(); + EMC_SEMAPHORE_GIVE(); } else { _setState(State::disconnectingTcp1); _disconnectReason = DisconnectReason::TCP_DISCONNECTED; @@ -246,10 +255,12 @@ void MqttClient::loop() { case State::disconnectingMqtt2: if (_transport->connected()) { // CONNECT packet is first in the queue + EMC_SEMAPHORE_TAKE(); _checkOutbox(); _checkIncoming(); _checkPing(); _checkTimeout(); + EMC_SEMAPHORE_GIVE(); } else { _setState(State::disconnectingTcp1); _disconnectReason = DisconnectReason::TCP_DISCONNECTED; @@ -262,15 +273,16 @@ void MqttClient::loop() { EMC_SEMAPHORE_GIVE(); emc_log_e("Could not create DISCONNECT packet"); _onError(0, Error::OUT_OF_MEMORY); + EMC_SEMAPHORE_TAKE(); } else { _setState(State::disconnectingMqtt2); } } - EMC_SEMAPHORE_GIVE(); _checkOutbox(); _checkIncoming(); _checkPing(); _checkTimeout(); + EMC_SEMAPHORE_GIVE(); break; case State::disconnectingTcp1: _transport->stop(); @@ -278,10 +290,14 @@ void MqttClient::loop() { break; // keep break to accomodate async clients case State::disconnectingTcp2: if (_transport->disconnected()) { + EMC_SEMAPHORE_TAKE(); _clearQueue(0); + EMC_SEMAPHORE_GIVE(); _bytesSent = 0; _setState(State::disconnected); - if (_onDisconnectCallback) _onDisconnectCallback(_disconnectReason); + if (_onDisconnectCallback) { + _onDisconnectCallback(_disconnectReason); + } } break; // all cases covered, no default case @@ -332,14 +348,12 @@ void MqttClient::_checkOutbox() { } int MqttClient::_sendPacket() { - EMC_SEMAPHORE_TAKE(); OutgoingPacket* packet = _outbox.getCurrent(); size_t written = 0; if (packet) { size_t wantToWrite = packet->packet.available(_bytesSent); if (wantToWrite == 0) { - EMC_SEMAPHORE_GIVE(); return 0; } written = _transport->write(packet->packet.data(_bytesSent), wantToWrite); @@ -348,12 +362,10 @@ int MqttClient::_sendPacket() { _bytesSent += written; emc_log_i("tx %zu/%zu (%02x)", _bytesSent, packet->packet.size(), packet->packet.packetType()); } - EMC_SEMAPHORE_GIVE(); return written; } bool MqttClient::_advanceOutbox() { - EMC_SEMAPHORE_TAKE(); OutgoingPacket* packet = _outbox.getCurrent(); if (packet && _bytesSent == packet->packet.size()) { if ((packet->packet.packetType()) == PacketType.DISCONNECT) { @@ -370,7 +382,6 @@ bool MqttClient::_advanceOutbox() { packet = _outbox.getCurrent(); _bytesSent = 0; } - EMC_SEMAPHORE_GIVE(); return packet; } @@ -390,7 +401,7 @@ void MqttClient::_checkIncoming() { _setState(State::disconnectingTcp1); return; } - switch (packetType & 0xF0) { + switch (packetType) { case PacketType.CONNACK: _onConnack(); if (_state != State::connected) { @@ -455,19 +466,15 @@ void MqttClient::_checkPing() { if (!_pingSent && ((currentMillis - _lastClientActivity > _keepAlive) || (currentMillis - _lastServerActivity > _keepAlive))) { - EMC_SEMAPHORE_TAKE(); if (!_addPacket(PacketType.PINGREQ)) { - EMC_SEMAPHORE_GIVE(); emc_log_e("Could not create PING packet"); return; } - EMC_SEMAPHORE_GIVE(); _pingSent = true; } } void MqttClient::_checkTimeout() { - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); // check that we're not busy sending // don't check when first item hasn't been sent yet @@ -477,7 +484,6 @@ void MqttClient::_checkTimeout() { _outbox.resetCurrent(); } } - EMC_SEMAPHORE_GIVE(); } void MqttClient::_onConnack() { @@ -489,7 +495,9 @@ void MqttClient::_onConnack() { _clearQueue(1); } if (_onConnectCallback) { + EMC_SEMAPHORE_GIVE(); _onConnectCallback(_parser.getPacket().variableHeader.fixed.connackVarHeader.sessionPresent); + EMC_SEMAPHORE_TAKE(); } } else { _setState(State::disconnectingTcp1); @@ -507,14 +515,11 @@ void MqttClient::_onPublish() { bool callback = true; if (qos == 1) { if (p.payload.index + p.payload.length == p.payload.total) { - EMC_SEMAPHORE_TAKE(); if (!_addPacket(PacketType.PUBACK, packetId)) { emc_log_e("Could not create PUBACK packet"); } - EMC_SEMAPHORE_GIVE(); } } else if (qos == 2) { - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); while (it) { if ((it.get()->packet.packetType()) == PacketType.PUBREC && it.get()->packet.packetId() == packetId) { @@ -529,20 +534,22 @@ void MqttClient::_onPublish() { emc_log_e("Could not create PUBREC packet"); } } + } + if (callback && _onMessageCallback) { EMC_SEMAPHORE_GIVE(); + _onMessageCallback({qos, dup, retain, packetId}, + p.variableHeader.topic, + p.payload.data, + p.payload.length, + p.payload.index, + p.payload.total); + EMC_SEMAPHORE_TAKE(); } - if (callback && _onMessageCallback) _onMessageCallback({qos, dup, retain, packetId}, - p.variableHeader.topic, - p.payload.data, - p.payload.length, - p.payload.index, - p.payload.total); } void MqttClient::_onPuback() { bool callback = false; uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId; - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); while (it) { // PUBACKs come in the order PUBs are sent. So we only check the first PUB packet in outbox @@ -558,9 +565,12 @@ void MqttClient::_onPuback() { } ++it; } - EMC_SEMAPHORE_GIVE(); if (callback) { - if (_onPublishCallback) _onPublishCallback(idToMatch); + if (_onPublishCallback) { + EMC_SEMAPHORE_GIVE(); + _onPublishCallback(idToMatch); + EMC_SEMAPHORE_TAKE(); + } } else { emc_log_w("No matching PUBLISH packet found"); } @@ -569,7 +579,6 @@ void MqttClient::_onPuback() { void MqttClient::_onPubrec() { bool success = false; uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId; - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); while (it) { // PUBRECs come in the order PUBs are sent. So we only check the first PUB packet in outbox @@ -591,13 +600,11 @@ void MqttClient::_onPubrec() { if (!success) { emc_log_w("No matching PUBLISH packet found"); } - EMC_SEMAPHORE_GIVE(); } void MqttClient::_onPubrel() { bool success = false; uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId; - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); while (it) { // PUBRELs come in the order PUBRECs are sent. So we only check the first PUBREC packet in outbox @@ -619,12 +626,10 @@ void MqttClient::_onPubrel() { if (!success) { emc_log_w("No matching PUBREC packet found"); } - EMC_SEMAPHORE_GIVE(); } void MqttClient::_onPubcomp() { bool callback = false; - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId; while (it) { @@ -641,9 +646,12 @@ void MqttClient::_onPubcomp() { } ++it; } - EMC_SEMAPHORE_GIVE(); if (callback) { - if (_onPublishCallback) _onPublishCallback(idToMatch); + if (_onPublishCallback) { + EMC_SEMAPHORE_GIVE(); + _onPublishCallback(idToMatch); + EMC_SEMAPHORE_TAKE(); + } } else { emc_log_w("No matching PUBREL packet found"); } @@ -652,7 +660,6 @@ void MqttClient::_onPubcomp() { void MqttClient::_onSuback() { bool callback = false; uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId; - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); while (it) { if (((it.get()->packet.packetType()) == PacketType.SUBSCRIBE) && it.get()->packet.packetId() == idToMatch) { @@ -662,9 +669,12 @@ void MqttClient::_onSuback() { } ++it; } - EMC_SEMAPHORE_GIVE(); if (callback) { - if (_onSubscribeCallback) _onSubscribeCallback(idToMatch, reinterpret_cast(_parser.getPacket().payload.data), _parser.getPacket().payload.total); + if (_onSubscribeCallback) { + EMC_SEMAPHORE_GIVE(); + _onSubscribeCallback(idToMatch, reinterpret_cast(_parser.getPacket().payload.data), _parser.getPacket().payload.total); + EMC_SEMAPHORE_TAKE(); + } } else { emc_log_w("received SUBACK without SUB"); } @@ -672,7 +682,6 @@ void MqttClient::_onSuback() { void MqttClient::_onUnsuback() { bool callback = false; - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId; while (it) { @@ -683,9 +692,12 @@ void MqttClient::_onUnsuback() { } ++it; } - EMC_SEMAPHORE_GIVE(); if (callback) { - if (_onUnsubscribeCallback) _onUnsubscribeCallback(idToMatch); + if (_onUnsubscribeCallback) { + EMC_SEMAPHORE_GIVE(); + _onUnsubscribeCallback(idToMatch); + EMC_SEMAPHORE_TAKE(); + } } else { emc_log_w("received UNSUBACK without UNSUB"); } @@ -693,7 +705,6 @@ void MqttClient::_onUnsuback() { void MqttClient::_clearQueue(int clearData) { emc_log_i("clearing queue (clear session: %d)", clearData); - EMC_SEMAPHORE_TAKE(); espMqttClientInternals::Outbox::Iterator it = _outbox.front(); if (clearData == 0) { // keep PUB (qos > 0, aka packetID != 0), PUBREC and PUBREL @@ -723,7 +734,6 @@ void MqttClient::_clearQueue(int clearData) { _outbox.remove(it); } } - EMC_SEMAPHORE_GIVE(); } void MqttClient::_onError(uint16_t packetId, espMqttClientTypes::Error error) { diff --git a/src/MqttClient.h b/src/MqttClient.h index 4997942..eaf9d2d 100644 --- a/src/MqttClient.h +++ b/src/MqttClient.h @@ -32,11 +32,12 @@ class MqttClient { bool disconnect(bool force = false); template uint16_t subscribe(const char* topic, uint8_t qos, Args&&... args) { - uint16_t packetId = _getNextPacketId(); + uint16_t packetId = 0; if (_state != State::connected) { - packetId = 0; + return packetId; } else { EMC_SEMAPHORE_TAKE(); + packetId = _getNextPacketId(); if (!_addPacket(packetId, topic, qos, std::forward(args) ...)) { emc_log_e("Could not create SUBSCRIBE packet"); packetId = 0; @@ -47,11 +48,12 @@ class MqttClient { } template uint16_t unsubscribe(const char* topic, Args&&... args) { - uint16_t packetId = _getNextPacketId(); + uint16_t packetId = 0; if (_state != State::connected) { - packetId = 0; + return packetId; } else { EMC_SEMAPHORE_TAKE(); + packetId = _getNextPacketId(); if (!_addPacket(packetId, topic, std::forward(args) ...)) { emc_log_e("Could not create UNSUBSCRIBE packet"); packetId = 0;