diff --git a/inc_internal/message.h b/inc_internal/message.h index d933c80a..e9ffb50f 100644 --- a/inc_internal/message.h +++ b/inc_internal/message.h @@ -27,7 +27,10 @@ #define MAGIC_INIT {0x3, 0x6, 0x9, 0xC} -typedef char magic_t[4]; +typedef union { + char magic[4]; + int32_t magint; +} magic_t; #define HEADER_FIELDS(XX) \ XX(content, uint32_t)\ @@ -52,7 +55,7 @@ static header_t EMPTY_HEADER = { typedef struct { uint32_t header_id; uint32_t length; - uint8_t *value; + const uint8_t *value; } hdr_t; #define var_header(id, var) header(id, sizeof(var), &(var)) @@ -97,9 +100,9 @@ bool message_get_bytes_header(message *m, int header_id, const uint8_t **ptr, si uint8_t *write_hdr(const hdr_t *h, uint8_t *buf); -int parse_hdrs(uint8_t *buf, uint32_t len, hdr_t **hp); +int parse_hdrs(const uint8_t *buf, uint32_t len, hdr_t **hp); -message *message_new_from_header(pool_t *pool, uint8_t buf[HEADER_SIZE]); +int message_new_from_header(pool_t *pool, uint8_t buf[HEADER_SIZE], message **msg_p); message *message_new(pool_t *pool, uint32_t content, const hdr_t *headers, int nheaders, size_t body_len); diff --git a/includes/ziti/errors.h b/includes/ziti/errors.h index 84267962..1f7d6ff7 100644 --- a/includes/ziti/errors.h +++ b/includes/ziti/errors.h @@ -111,6 +111,8 @@ is offline or did not respond to the request*/ #define ZITI_CERT_FAILED_VALIDATION (-36) /** returned when the certificate doesn't have an externalId") \*/ #define ZITI_MISSING_CERT_CLAIM (-37) +/** ziti could not allocate memory */ +#define ZITI_ALLOC_FAILED (-38) // Put new error codes here and add error string in error.c diff --git a/library/channel.c b/library/channel.c index a33e4aba..1cfbf79b 100644 --- a/library/channel.c +++ b/library/channel.c @@ -508,8 +508,6 @@ static bool is_edge(uint32_t content) { static void dispatch_message(ziti_channel_t *ch, message *m) { struct waiter_s *w = NULL; - m->nhdrs = parse_hdrs(m->headers, m->header.headers_len, &m->hdrs); - uint32_t reply_to; bool is_reply = message_get_int32_header(m, ReplyForHeader, (int32_t*)&reply_to); @@ -566,6 +564,7 @@ static void dispatch_message(ziti_channel_t *ch, message *m) { static void process_inbound(ziti_channel_t *ch) { uint8_t *ptr; ssize_t len; + int rc = 0; do { if (ch->in_next == NULL && pool_has_available(ch->in_msg_pool)) { if (buffer_available(ch->incoming) < HEADER_SIZE) { @@ -583,7 +582,9 @@ static void process_inbound(ziti_channel_t *ch) { assert(header_read == HEADER_SIZE); - ch->in_next = message_new_from_header(ch->in_msg_pool, header_buf); + + rc = message_new_from_header(ch->in_msg_pool, header_buf, &ch->in_next); + if (rc != ZITI_OK) break; ch->in_body_offset = 0; CH_LOG(TRACE, "<= ct[%04X] seq[%d] len[%d] hdrs[%d]", ch->in_next->header.content, @@ -608,17 +609,29 @@ static void process_inbound(ziti_channel_t *ch) { ch->in_body_offset += len; if (ch->in_body_offset == total) { - CH_LOG(TRACE, "message is complete seq[%d] ct[%04X]", ch->in_next->header.seq, - ch->in_next->header.content); - - dispatch_message(ch, ch->in_next); - + message *msg = ch->in_next; ch->in_next = NULL; + + CH_LOG(TRACE, "message is complete seq[%d] ct[%04X]", + msg->header.seq, msg->header.content); + + rc = parse_hdrs(msg->headers, msg->header.headers_len, &msg->hdrs); + if (rc < 0) { + pool_return_obj(msg); + CH_LOG(ERROR, "failed to parse incoming message: %s", ziti_errorstr(rc)); + break; + } + msg->nhdrs = rc; + rc = 0; + dispatch_message(ch, msg); } } } while (1); buffer_cleanup(ch->incoming); + if (rc != 0) { + on_channel_close(ch, rc, 0); + } } static void latency_reply_cb(void *ctx, message *reply, int err) { diff --git a/library/errors.c b/library/errors.c index 3b9f6b01..d02c17fe 100644 --- a/library/errors.c +++ b/library/errors.c @@ -53,6 +53,7 @@ XX(CERT_IN_USE,"the provided certificate already in use") \ XX(CERT_FAILED_VALIDATION, "the provided key/cert are invalid") \ XX(MISSING_CERT_CLAIM, "the certificate is expected to contain an externalId but none was not found") \ + XX(ALLOC_FAILED, "memory allocation failed") \ XX(WTF, "WTF: programming error") diff --git a/library/message.c b/library/message.c index 35516133..7596fa61 100644 --- a/library/message.c +++ b/library/message.c @@ -15,6 +15,7 @@ #include "message.h" #include #include +#include #include "utils.h" #include "endian_internal.h" @@ -25,7 +26,7 @@ static const uint8_t *read_int32(const uint8_t *p, uint32_t *val) { } void header_to_buffer(header_t *h, uint8_t *buf) { - memcpy(buf, h->magic, sizeof(h->magic)); + memcpy(buf, h->magic.magic, sizeof(h->magic)); uint8_t *offset = buf + sizeof(h->magic); #define write_field(n,t) {\ @@ -38,7 +39,7 @@ void header_to_buffer(header_t *h, uint8_t *buf) { }; void header_from_buffer(header_t *h, uint8_t *buf) { - memcpy(h->magic, buf, sizeof(h->magic)); + memcpy(h->magic.magic, buf, sizeof(h->magic)); uint8_t *offset = buf + sizeof(h->magic); #define read_field(n,t) {\ @@ -83,26 +84,39 @@ uint8_t *write_hdr(const hdr_t *h, uint8_t *buf) { return buf + h->length; } -int parse_hdrs(uint8_t *buf, uint32_t len, hdr_t **hp) { +int parse_hdrs(const uint8_t *buf, uint32_t len, hdr_t **hp) { const uint8_t *p = buf; - hdr_t *headers = NULL; int count = 0; - while (p < buf + len) { - if (headers == NULL) { - headers = malloc(sizeof(hdr_t)); - } - else { - headers = realloc(headers, (count + 1) * sizeof(hdr_t)); - } - - p = read_int32(p, &headers[count].header_id); - p = read_int32(p, &headers[count].length); - headers[count].value = p; - p += headers[count].length; + while (p < buf + len - 2 * sizeof(uint32_t)) { + uint32_t length; + p += sizeof(uint32_t); + p = read_int32(p, &length); + p += length; count++; } + if (p != buf + len) { + ZITI_LOG(ERROR, "misaligned message headers"); + return ZITI_INVALID_STATE; + } + + hdr_t *headers = calloc(count, sizeof(hdr_t)); + if (headers == NULL) { + ZITI_LOG(ERROR, "failed to allocates message headers"); + return ZITI_ALLOC_FAILED; + } + + p = buf; + int idx = 0; + while (p < buf + len) { + p = read_int32(p, &headers[idx].header_id); + p = read_int32(p, &headers[idx].length); + headers[idx].value = p; + p += headers[idx].length; + idx++; + } + *hp = headers; return count; } @@ -119,7 +133,7 @@ static hdr_t *find_header(message *m, int header_id) { bool message_get_bool_header(message *m, int header_id, bool *v) { hdr_t *h = find_header(m, header_id); if (h != NULL) { - int8_t val = *h->value; + char val = (char)h->value[0]; *v = (val != 0); return true; } @@ -165,19 +179,30 @@ bool message_get_bytes_header(message *m, int header_id, const uint8_t **v, size return false; } -message *message_new_from_header(pool_t *pool, uint8_t buf[HEADER_SIZE]) { +int message_new_from_header(pool_t *pool, uint8_t buf[HEADER_SIZE], message **msg_p) { header_t h; header_from_buffer(&h, buf); + if (h.magic.magint != EMPTY_HEADER.magic.magint) { + return ZITI_INVALID_STATE; + } + size_t msgbuflen = HEADER_SIZE + h.headers_len + h.body_len; message *m = pool ? pool_alloc_obj(pool) : alloc_unpooled_obj(sizeof(message) + msgbuflen, (void (*)(void *)) message_free); + if (m == NULL) { + return ZITI_ALLOC_FAILED; + } m->msgbuflen = msgbuflen; size_t msgsize = sizeof(message) + msgbuflen; if (msgsize > pool_obj_size(m)) { m->msgbufp = malloc(msgbuflen); + if (m->msgbufp == NULL) { + pool_return_obj(m); + return ZITI_ALLOC_FAILED; + } } else { m->msgbufp = m->msgbuf; @@ -186,7 +211,8 @@ message *message_new_from_header(pool_t *pool, uint8_t buf[HEADER_SIZE]) { memcpy(&m->header, &h, sizeof(h)); m->headers = m->msgbufp + HEADER_SIZE; m->body = m->headers + h.headers_len; - return m; + *msg_p = m; + return ZITI_OK; } message *message_new(pool_t *pool, uint32_t content, const hdr_t *hdrs, int nhdrs, size_t body_len) { diff --git a/programs/ziti-prox-c/main.cpp b/programs/ziti-prox-c/main.cpp index 99a974f2..cc540a34 100644 --- a/programs/ziti-prox-c/main.cpp +++ b/programs/ziti-prox-c/main.cpp @@ -32,7 +32,7 @@ class Run: public CLI::App { Run(): App("run proxy", "run"), debug(2) { - add_option("--debug,-d", debug, "log level")->envname("ZITI_LOG"); + add_option("--debug,-d", debug, "log level"); add_option("--identity,-i", identity, "identity config")->required(); add_option("listener", intercepts, ""); add_option("--bind,-b", bindings, "bind service "); diff --git a/tests/message_tests.cpp b/tests/message_tests.cpp index c0536155..eb320b5b 100644 --- a/tests/message_tests.cpp +++ b/tests/message_tests.cpp @@ -17,6 +17,7 @@ #include #include "message.h" #include "edge_protocol.h" +#include "ziti/errors.h" TEST_CASE("simple", "[model]") { auto p = pool_new(sizeof(message) + 200, 3, (void (*)(void *)) message_free); @@ -39,7 +40,8 @@ TEST_CASE("simple", "[model]") { strncpy(reinterpret_cast(m1->body), content1, strlen(content1)); message_set_seq(m1, &s1); - auto m2 = message_new_from_header(p, m1->msgbufp); + message *m2; + REQUIRE(message_new_from_header(p, m1->msgbufp, &m2) == ZITI_OK); CHECK(m2->header.seq == 3334); CHECK(m2->msgbuflen == m1->msgbuflen); memcpy(m2->msgbufp, m1->msgbufp, m1->msgbuflen); @@ -83,7 +85,8 @@ TEST_CASE("large", "[model]") { message_set_seq(m1, &seq); - auto m2 = message_new_from_header(p, m1->msgbufp); + message *m2; + REQUIRE(message_new_from_header(p, m1->msgbufp, &m2) == ZITI_OK); CHECK(m2->header.seq == 3334); CHECK(seq == 3334); CHECK(m2->msgbuflen == m1->msgbuflen); @@ -126,7 +129,8 @@ TEST_CASE("large unpooled", "[model]") { strncpy(reinterpret_cast(m1->body), content1, strlen(content1)); message_set_seq(m1, &seq); - auto m2 = message_new_from_header(nullptr, m1->msgbufp); + message *m2; + REQUIRE(message_new_from_header(nullptr, m1->msgbufp, &m2) == ZITI_OK); CHECK(m2->header.seq == 3334); CHECK(seq == 3334); CHECK(m2->msgbuflen == m1->msgbuflen);