Skip to content

Commit

Permalink
simplify concurrency guarding
Browse files Browse the repository at this point in the history
  • Loading branch information
bertmelis committed May 27, 2024
1 parent 983015f commit cb0c386
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 45 deletions.
92 changes: 51 additions & 41 deletions src/MqttClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,10 @@ bool MqttClient::connect() {
}
#endif
} else {
EMC_SEMAPHORE_GIVE();
emc_log_e("Could not create CONNECT packet");
EMC_SEMAPHORE_GIVE();
_onError(0, Error::OUT_OF_MEMORY);
EMC_SEMAPHORE_TAKE();
}
EMC_SEMAPHORE_GIVE();
}
Expand Down Expand Up @@ -151,7 +152,9 @@ uint16_t MqttClient::publish(const char* topic, uint8_t qos, bool retain, const
uint16_t packetId = (qos > 0) ? _getNextPacketId() : 1;
if (!_addPacket(packetId, topic, payload, length, qos, retain)) {
emc_log_e("Could not create PUBLISH packet");
EMC_SEMAPHORE_GIVE();
_onError(packetId, Error::OUT_OF_MEMORY);
EMC_SEMAPHORE_TAKE();
packetId = 0;
}
EMC_SEMAPHORE_GIVE();
Expand All @@ -175,15 +178,19 @@ uint16_t MqttClient::publish(const char* topic, uint8_t qos, bool retain, espMqt
uint16_t packetId = (qos > 0) ? _getNextPacketId() : 1;
if (!_addPacket(packetId, topic, callback, length, qos, retain)) {
emc_log_e("Could not create PUBLISH packet");
EMC_SEMAPHORE_GIVE();
_onError(packetId, Error::OUT_OF_MEMORY);
EMC_SEMAPHORE_TAKE();
packetId = 0;
}
EMC_SEMAPHORE_GIVE();
return packetId;
}

void MqttClient::clearQueue(bool deleteSessionData) {
EMC_SEMAPHORE_TAKE();
_clearQueue(deleteSessionData ? 2 : 0);
EMC_SEMAPHORE_GIVE();
}

const char* MqttClient::getClientId() const {
Expand Down Expand Up @@ -227,9 +234,11 @@ void MqttClient::loop() {
case State::connectingMqtt:
#if EMC_WAIT_FOR_CONNACK
if (_transport->connected()) {
EMC_SEMAPHORE_TAKE();
_sendPacket();
_checkIncoming();
_checkPing();
EMC_SEMAPHORE_GIVE();
} else {
_setState(State::disconnectingTcp1);
_disconnectReason = DisconnectReason::TCP_DISCONNECTED;
Expand All @@ -246,10 +255,12 @@ void MqttClient::loop() {
case State::disconnectingMqtt2:
if (_transport->connected()) {
// CONNECT packet is first in the queue
EMC_SEMAPHORE_TAKE();
_checkOutbox();
_checkIncoming();
_checkPing();
_checkTimeout();
EMC_SEMAPHORE_GIVE();
} else {
_setState(State::disconnectingTcp1);
_disconnectReason = DisconnectReason::TCP_DISCONNECTED;
Expand All @@ -262,26 +273,31 @@ void MqttClient::loop() {
EMC_SEMAPHORE_GIVE();
emc_log_e("Could not create DISCONNECT packet");
_onError(0, Error::OUT_OF_MEMORY);
EMC_SEMAPHORE_TAKE();
} else {
_setState(State::disconnectingMqtt2);
}
}
EMC_SEMAPHORE_GIVE();
_checkOutbox();
_checkIncoming();
_checkPing();
_checkTimeout();
EMC_SEMAPHORE_GIVE();
break;
case State::disconnectingTcp1:
_transport->stop();
_setState(State::disconnectingTcp2);
break; // keep break to accomodate async clients
case State::disconnectingTcp2:
if (_transport->disconnected()) {
EMC_SEMAPHORE_TAKE();
_clearQueue(0);
EMC_SEMAPHORE_GIVE();
_bytesSent = 0;
_setState(State::disconnected);
if (_onDisconnectCallback) _onDisconnectCallback(_disconnectReason);
if (_onDisconnectCallback) {
_onDisconnectCallback(_disconnectReason);
}
}
break;
// all cases covered, no default case
Expand Down Expand Up @@ -332,14 +348,12 @@ void MqttClient::_checkOutbox() {
}

int MqttClient::_sendPacket() {
EMC_SEMAPHORE_TAKE();
OutgoingPacket* packet = _outbox.getCurrent();

size_t written = 0;
if (packet) {
size_t wantToWrite = packet->packet.available(_bytesSent);
if (wantToWrite == 0) {
EMC_SEMAPHORE_GIVE();
return 0;
}
written = _transport->write(packet->packet.data(_bytesSent), wantToWrite);
Expand All @@ -348,12 +362,10 @@ int MqttClient::_sendPacket() {
_bytesSent += written;
emc_log_i("tx %zu/%zu (%02x)", _bytesSent, packet->packet.size(), packet->packet.packetType());
}
EMC_SEMAPHORE_GIVE();
return written;
}

bool MqttClient::_advanceOutbox() {
EMC_SEMAPHORE_TAKE();
OutgoingPacket* packet = _outbox.getCurrent();
if (packet && _bytesSent == packet->packet.size()) {
if ((packet->packet.packetType()) == PacketType.DISCONNECT) {
Expand All @@ -370,7 +382,6 @@ bool MqttClient::_advanceOutbox() {
packet = _outbox.getCurrent();
_bytesSent = 0;
}
EMC_SEMAPHORE_GIVE();
return packet;
}

Expand All @@ -390,7 +401,7 @@ void MqttClient::_checkIncoming() {
_setState(State::disconnectingTcp1);
return;
}
switch (packetType & 0xF0) {
switch (packetType) {
case PacketType.CONNACK:
_onConnack();
if (_state != State::connected) {
Expand Down Expand Up @@ -455,19 +466,15 @@ void MqttClient::_checkPing() {
if (!_pingSent &&
((currentMillis - _lastClientActivity > _keepAlive) ||
(currentMillis - _lastServerActivity > _keepAlive))) {
EMC_SEMAPHORE_TAKE();
if (!_addPacket(PacketType.PINGREQ)) {
EMC_SEMAPHORE_GIVE();
emc_log_e("Could not create PING packet");
return;
}
EMC_SEMAPHORE_GIVE();
_pingSent = true;
}
}

void MqttClient::_checkTimeout() {
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
// check that we're not busy sending
// don't check when first item hasn't been sent yet
Expand All @@ -477,7 +484,6 @@ void MqttClient::_checkTimeout() {
_outbox.resetCurrent();
}
}
EMC_SEMAPHORE_GIVE();
}

void MqttClient::_onConnack() {
Expand All @@ -489,7 +495,9 @@ void MqttClient::_onConnack() {
_clearQueue(1);
}
if (_onConnectCallback) {
EMC_SEMAPHORE_GIVE();
_onConnectCallback(_parser.getPacket().variableHeader.fixed.connackVarHeader.sessionPresent);
EMC_SEMAPHORE_TAKE();
}
} else {
_setState(State::disconnectingTcp1);
Expand All @@ -507,14 +515,11 @@ void MqttClient::_onPublish() {
bool callback = true;
if (qos == 1) {
if (p.payload.index + p.payload.length == p.payload.total) {
EMC_SEMAPHORE_TAKE();
if (!_addPacket(PacketType.PUBACK, packetId)) {
emc_log_e("Could not create PUBACK packet");
}
EMC_SEMAPHORE_GIVE();
}
} else if (qos == 2) {
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
while (it) {
if ((it.get()->packet.packetType()) == PacketType.PUBREC && it.get()->packet.packetId() == packetId) {
Expand All @@ -529,20 +534,22 @@ void MqttClient::_onPublish() {
emc_log_e("Could not create PUBREC packet");
}
}
}
if (callback && _onMessageCallback) {
EMC_SEMAPHORE_GIVE();
_onMessageCallback({qos, dup, retain, packetId},
p.variableHeader.topic,
p.payload.data,
p.payload.length,
p.payload.index,
p.payload.total);
EMC_SEMAPHORE_TAKE();
}
if (callback && _onMessageCallback) _onMessageCallback({qos, dup, retain, packetId},
p.variableHeader.topic,
p.payload.data,
p.payload.length,
p.payload.index,
p.payload.total);
}

void MqttClient::_onPuback() {
bool callback = false;
uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId;
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
while (it) {
// PUBACKs come in the order PUBs are sent. So we only check the first PUB packet in outbox
Expand All @@ -558,9 +565,12 @@ void MqttClient::_onPuback() {
}
++it;
}
EMC_SEMAPHORE_GIVE();
if (callback) {
if (_onPublishCallback) _onPublishCallback(idToMatch);
if (_onPublishCallback) {
EMC_SEMAPHORE_GIVE();
_onPublishCallback(idToMatch);
EMC_SEMAPHORE_TAKE();
}
} else {
emc_log_w("No matching PUBLISH packet found");
}
Expand All @@ -569,7 +579,6 @@ void MqttClient::_onPuback() {
void MqttClient::_onPubrec() {
bool success = false;
uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId;
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
while (it) {
// PUBRECs come in the order PUBs are sent. So we only check the first PUB packet in outbox
Expand All @@ -591,13 +600,11 @@ void MqttClient::_onPubrec() {
if (!success) {
emc_log_w("No matching PUBLISH packet found");
}
EMC_SEMAPHORE_GIVE();
}

void MqttClient::_onPubrel() {
bool success = false;
uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId;
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
while (it) {
// PUBRELs come in the order PUBRECs are sent. So we only check the first PUBREC packet in outbox
Expand All @@ -619,12 +626,10 @@ void MqttClient::_onPubrel() {
if (!success) {
emc_log_w("No matching PUBREC packet found");
}
EMC_SEMAPHORE_GIVE();
}

void MqttClient::_onPubcomp() {
bool callback = false;
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId;
while (it) {
Expand All @@ -641,9 +646,12 @@ void MqttClient::_onPubcomp() {
}
++it;
}
EMC_SEMAPHORE_GIVE();
if (callback) {
if (_onPublishCallback) _onPublishCallback(idToMatch);
if (_onPublishCallback) {
EMC_SEMAPHORE_GIVE();
_onPublishCallback(idToMatch);
EMC_SEMAPHORE_TAKE();
}
} else {
emc_log_w("No matching PUBREL packet found");
}
Expand All @@ -652,7 +660,6 @@ void MqttClient::_onPubcomp() {
void MqttClient::_onSuback() {
bool callback = false;
uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId;
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
while (it) {
if (((it.get()->packet.packetType()) == PacketType.SUBSCRIBE) && it.get()->packet.packetId() == idToMatch) {
Expand All @@ -662,17 +669,19 @@ void MqttClient::_onSuback() {
}
++it;
}
EMC_SEMAPHORE_GIVE();
if (callback) {
if (_onSubscribeCallback) _onSubscribeCallback(idToMatch, reinterpret_cast<const espMqttClientTypes::SubscribeReturncode*>(_parser.getPacket().payload.data), _parser.getPacket().payload.total);
if (_onSubscribeCallback) {
EMC_SEMAPHORE_GIVE();
_onSubscribeCallback(idToMatch, reinterpret_cast<const espMqttClientTypes::SubscribeReturncode*>(_parser.getPacket().payload.data), _parser.getPacket().payload.total);
EMC_SEMAPHORE_TAKE();
}
} else {
emc_log_w("received SUBACK without SUB");
}
}

void MqttClient::_onUnsuback() {
bool callback = false;
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
uint16_t idToMatch = _parser.getPacket().variableHeader.fixed.packetId;
while (it) {
Expand All @@ -683,17 +692,19 @@ void MqttClient::_onUnsuback() {
}
++it;
}
EMC_SEMAPHORE_GIVE();
if (callback) {
if (_onUnsubscribeCallback) _onUnsubscribeCallback(idToMatch);
if (_onUnsubscribeCallback) {
EMC_SEMAPHORE_GIVE();
_onUnsubscribeCallback(idToMatch);
EMC_SEMAPHORE_TAKE();
}
} else {
emc_log_w("received UNSUBACK without UNSUB");
}
}

void MqttClient::_clearQueue(int clearData) {
emc_log_i("clearing queue (clear session: %d)", clearData);
EMC_SEMAPHORE_TAKE();
espMqttClientInternals::Outbox<OutgoingPacket>::Iterator it = _outbox.front();
if (clearData == 0) {
// keep PUB (qos > 0, aka packetID != 0), PUBREC and PUBREL
Expand Down Expand Up @@ -723,7 +734,6 @@ void MqttClient::_clearQueue(int clearData) {
_outbox.remove(it);
}
}
EMC_SEMAPHORE_GIVE();
}

void MqttClient::_onError(uint16_t packetId, espMqttClientTypes::Error error) {
Expand Down
10 changes: 6 additions & 4 deletions src/MqttClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ class MqttClient {
bool disconnect(bool force = false);
template <typename... Args>
uint16_t subscribe(const char* topic, uint8_t qos, Args&&... args) {
uint16_t packetId = _getNextPacketId();
uint16_t packetId = 0;
if (_state != State::connected) {
packetId = 0;
return packetId;
} else {
EMC_SEMAPHORE_TAKE();
packetId = _getNextPacketId();
if (!_addPacket(packetId, topic, qos, std::forward<Args>(args) ...)) {
emc_log_e("Could not create SUBSCRIBE packet");
packetId = 0;
Expand All @@ -47,11 +48,12 @@ class MqttClient {
}
template <typename... Args>
uint16_t unsubscribe(const char* topic, Args&&... args) {
uint16_t packetId = _getNextPacketId();
uint16_t packetId = 0;
if (_state != State::connected) {
packetId = 0;
return packetId;
} else {
EMC_SEMAPHORE_TAKE();
packetId = _getNextPacketId();
if (!_addPacket(packetId, topic, std::forward<Args>(args) ...)) {
emc_log_e("Could not create UNSUBSCRIBE packet");
packetId = 0;
Expand Down

0 comments on commit cb0c386

Please sign in to comment.