diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4ee0a59039..9a7b1fb1ce 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,6 +44,26 @@ jobs: - name: module api test run: CFLAGS='-Werror' ./runtest-moduleapi --verbose --dump-logs + test-rdma: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + - name: make + run: | + sudo apt-get install librdmacm-dev libibverbs-dev + make BUILD_RDMA=module + - name: clone-rxe-kmod + run: | + mkdir -p tests/rdma/rxe + git clone https://github.com/pizhenwei/rxe.git tests/rdma/rxe + make -C tests/rdma/rxe + - name: clear-kernel-log + run: sudo dmesg -c > /dev/null + - name: test + run: sudo ./runtest-rdma --install-rxe + - name: show-kernel-log + run: sudo dmesg -c + build-debian-old: runs-on: ubuntu-latest container: debian:buster diff --git a/.gitignore b/.gitignore index bc53c98865..3175ad4b4f 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,4 @@ redis.code-workspace .swp nodes*.conf tests/cluster/tmp/* +tests/rdma/rdma-test diff --git a/README.md b/README.md index 3c32be3941..592002d3c4 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,12 @@ libssl-dev on Debian/Ubuntu) and run: % make BUILD_TLS=yes +To build with experimental RDMA support you'll need RDMA development libraries +(e.g. librdmacm-dev and libibverbs-dev on Debian/Ubuntu). For now, Valkey only +supports RDMA as connection module mode. Run: + + % make BUILD_RDMA=module + To build with systemd support, you'll need systemd development libraries (such as libsystemd-dev on Debian/Ubuntu or systemd-devel on CentOS) and run: @@ -155,6 +161,38 @@ Running Valkey with TLS: Please consult the [TLS.md](TLS.md) file for more information on how to use Valkey with TLS. +Running Valkey with RDMA: +------------------ + +Note that Valkey Over RDMA is an experimental feature. +It may be changed or removed in any minor or major version. 
+Currently, it is only supported on Linux. + +To manually run a Valkey server with RDMA mode: + + % ./src/valkey-server --protected-mode no \ + --loadmodule src/valkey-rdma.so bind=192.168.122.100 port=6379 + +It's possible to change bind address/port of RDMA by runtime command: + + 192.168.122.100:6379> CONFIG SET rdma.port 6380 + +It's also possible to have both RDMA and TCP available, and there is no +conflict of TCP(6379) and RDMA(6379), Ex: + + % ./src/valkey-server --protected-mode no \ + --loadmodule src/valkey-rdma.so bind=192.168.122.100 port=6379 \ + --port 6379 + +Note that the network card (192.168.122.100 of this example) should support +RDMA. To test a server supports RDMA or not: + + % rdma res show (a new version iproute2 package) +Or: + + % ibv_devices + + Playing with Valkey ------------------ diff --git a/runtest-rdma b/runtest-rdma new file mode 100755 index 0000000000..5e44208f9f --- /dev/null +++ b/runtest-rdma @@ -0,0 +1 @@ +./tests/rdma/run.py "$@" diff --git a/src/Makefile b/src/Makefile index 4e8c34b253..eaf0e4e387 100644 --- a/src/Makefile +++ b/src/Makefile @@ -337,6 +337,28 @@ ifeq ($(BUILD_TLS),module) TLS_MODULE_CFLAGS+=-DUSE_OPENSSL=$(BUILD_MODULE) $(OPENSSL_CFLAGS) -DBUILD_TLS_MODULE=$(BUILD_MODULE) endif +BUILD_RDMA:=no +RDMA_MODULE= +RDMA_MODULE_NAME:=valkey-rdma$(PROG_SUFFIX).so +RDMA_MODULE_CFLAGS:=$(FINAL_CFLAGS) +ifeq ($(BUILD_RDMA),module) + FINAL_CFLAGS+=-DUSE_RDMA=$(BUILD_MODULE) + RDMA_PKGCONFIG := $(shell $(PKG_CONFIG) --exists librdmacm libibverbs && echo $$?) 
+ifeq ($(RDMA_PKGCONFIG),0) + RDMA_LIBS=$(shell $(PKG_CONFIG) --libs librdmacm libibverbs) +else + RDMA_LIBS=-lrdmacm -libverbs +endif + RDMA_MODULE=$(RDMA_MODULE_NAME) + RDMA_MODULE_CFLAGS+=-DUSE_RDMA=$(BUILD_YES) -DBUILD_RDMA_MODULE $(RDMA_LIBS) +else +ifeq ($(BUILD_RDMA),no) + # disable RDMA, do nothing +else + $(error "RDMA is only supported as module (BUILD_RDMA=module), or disabled (BUILD_RDMA=no)") +endif +endif + ifndef V define MAKE_INSTALL @printf ' %b %b\n' $(LINKCOLOR)INSTALL$(ENDCOLOR) $(BINCOLOR)$(1)$(ENDCOLOR) 1>&2 @@ -414,7 +436,7 @@ ENGINE_TEST_OBJ:=$(sort $(patsubst unit/%.c,unit/%.o,$(ENGINE_TEST_FILES))) ENGINE_UNIT_TESTS:=$(ENGINE_NAME)-unit-tests$(PROG_SUFFIX) ALL_SOURCES=$(sort $(patsubst %.o,%.c,$(ENGINE_SERVER_OBJ) $(ENGINE_CLI_OBJ) $(ENGINE_BENCHMARK_OBJ))) -all: $(SERVER_NAME) $(ENGINE_SENTINEL_NAME) $(ENGINE_CLI_NAME) $(ENGINE_BENCHMARK_NAME) $(ENGINE_CHECK_RDB_NAME) $(ENGINE_CHECK_AOF_NAME) $(TLS_MODULE) +all: $(SERVER_NAME) $(ENGINE_SENTINEL_NAME) $(ENGINE_CLI_NAME) $(ENGINE_BENCHMARK_NAME) $(ENGINE_CHECK_RDB_NAME) $(ENGINE_CHECK_AOF_NAME) $(TLS_MODULE) $(RDMA_MODULE) @echo "" @echo "Hint: It's a good idea to run 'make test' ;)" @echo "" @@ -437,6 +459,7 @@ persist-settings: distclean echo OPT=$(OPT) >> .make-settings echo MALLOC=$(MALLOC) >> .make-settings echo BUILD_TLS=$(BUILD_TLS) >> .make-settings + echo BUILD_RDMA=$(BUILD_RDMA) >> .make-settings echo USE_SYSTEMD=$(USE_SYSTEMD) >> .make-settings echo CFLAGS=$(CFLAGS) >> .make-settings echo LDFLAGS=$(LDFLAGS) >> .make-settings @@ -489,6 +512,10 @@ $(ENGINE_CHECK_AOF_NAME): $(SERVER_NAME) $(TLS_MODULE_NAME): $(SERVER_NAME) $(QUIET_CC)$(CC) -o $@ tls.c -shared -fPIC $(TLS_MODULE_CFLAGS) $(TLS_CLIENT_LIBS) +# valkey-rdma.so +$(RDMA_MODULE_NAME): $(SERVER_NAME) + $(QUIET_CC)$(CC) -o $@ rdma.c -shared -fPIC $(RDMA_MODULE_CFLAGS) + # valkey-cli $(ENGINE_CLI_NAME): $(ENGINE_CLI_OBJ) $(SERVER_LD) -o $@ $^ ../deps/hiredis/libhiredis.a ../deps/linenoise/linenoise.o $(FINAL_LIBS) 
$(TLS_CLIENT_LIBS) diff --git a/src/rdma.c b/src/rdma.c new file mode 100644 index 0000000000..f95b5b1229 --- /dev/null +++ b/src/rdma.c @@ -0,0 +1,1888 @@ +/* ========================================================================== + * rdma.c - support RDMA protocol for transport layer. + * -------------------------------------------------------------------------- + * Copyright (C) 2021-2024 zhenwei pi + * + * This work is licensed under BSD 3-Clause, License 1 of the COPYING file in + * the top-level directory. + * ========================================================================== + */ + +#define VALKEYMODULE_CORE_MODULE +#include "server.h" + +#if defined USE_RDMA && defined __linux__ /* currently RDMA is only supported on Linux */ +#include "connection.h" +#include "connhelpers.h" + +#include +#include +#include +#include +#include +#include +#include + +#define CONN_TYPE_RDMA "rdma" + +typedef struct ValkeyRdmaFeature { + /* defined as following Opcodes */ + uint16_t opcode; + /* select features */ + uint16_t select; + uint8_t rsvd[20]; + /* feature bits */ + uint64_t features; +} ValkeyRdmaFeature; + +typedef struct ValkeyRdmaKeepalive { + /* defined as following Opcodes */ + uint16_t opcode; + uint8_t rsvd[30]; +} ValkeyRdmaKeepalive; + +typedef struct ValkeyRdmaMemory { + /* defined as following Opcodes */ + uint16_t opcode; + uint8_t rsvd[14]; + /* address of a transfer buffer which is used to receive remote streaming data, + * aka 'RX buffer address'. 
The remote side should use this as 'TX buffer address' */ + uint64_t addr; + /* length of the 'RX buffer' */ + uint32_t length; + /* the RDMA remote key of 'RX buffer' */ + uint32_t key; +} ValkeyRdmaMemory; + +typedef union ValkeyRdmaCmd { + ValkeyRdmaFeature feature; + ValkeyRdmaKeepalive keepalive; + ValkeyRdmaMemory memory; +} ValkeyRdmaCmd; + +typedef enum ValkeyRdmaOpcode { + GetServerFeature = 0, + SetClientFeature = 1, + Keepalive = 2, + RegisterXferMemory = 3, +} ValkeyRdmaOpcode; + +#define VALKEY_BUILD_BUG_ON(cond) ((void)sizeof(char[1 - 2 * !!(cond)])) +#define VALKEY_RDMA_MAX_WQE 1024 +#define VALKEY_RDMA_DEFAULT_RX_SIZE (1024 * 1024) +#define VALKEY_RDMA_MIN_RX_SIZE (64 * 1024) +#define VALKEY_RDMA_MAX_RX_SIZE (16 * 1024 * 1024) +#define VALKEY_RDMA_SYNCIO_RES 10 +#define VALKEY_RDMA_INVALID_OPCODE 0xffff + +typedef struct rdma_connection { + connection c; + struct rdma_cm_id *cm_id; + int last_errno; + listNode *pending_list_node; +} rdma_connection; + +typedef struct RdmaXfer { + struct ibv_mr *mr; /* memory region of the transfer buffer */ + char *addr; /* address of transfer buffer in local memory */ + uint32_t length; /* bytes of transfer buffer */ + uint32_t offset; /* the offset of consumed transfer buffer */ + uint32_t pos; /* the position in use of the transfer buffer */ +} RdmaXfer; + +typedef struct RdmaContext { + connection *conn; + char *ip; + int port; + struct ibv_pd *pd; + struct rdma_event_channel *cm_channel; + struct ibv_comp_channel *comp_channel; + struct ibv_cq *cq; + + /* TX */ + RdmaXfer tx; + char *tx_addr; /* remote transfer buffer address */ + uint32_t tx_key; /* remote transfer buffer key */ + uint32_t tx_length; /* remote transfer buffer length */ + uint32_t tx_offset; /* remote transfer buffer offset */ + uint32_t tx_ops; /* operations on remote transfer */ + + /* RX */ + RdmaXfer rx; + + /* CMD 0 ~ VALKEY_RDMA_MAX_WQE for recv buffer + * VALKEY_RDMA_MAX_WQE ~ 2 * VALKEY_RDMA_MAX_WQE -1 for send buffer */ + ValkeyRdmaCmd 
*cmd_buf; + struct ibv_mr *cmd_mr; +} RdmaContext; + +typedef struct rdma_listener { + struct rdma_cm_id *cm_id; + struct rdma_event_channel *cm_channel; +} rdma_listener; + +/* RDMA connection is always writable, it has no POLLOUT event to drive the write handler, record available write + * handler into pending list */ +static list *pending_list; + +static ConnectionType CT_RDMA; + +static int valkey_rdma_rx_size = VALKEY_RDMA_DEFAULT_RX_SIZE; +static int valkey_rdma_comp_vector = -1; /* -1 means a random one */ + +static void serverRdmaError(char *err, const char *fmt, ...) { + va_list ap; + + if (!err) return; + va_start(ap, fmt); + vsnprintf(err, ANET_ERR_LEN, fmt, ap); + va_end(ap); +} + +static int rdmaPostRecv(RdmaContext *ctx, struct rdma_cm_id *cm_id, ValkeyRdmaCmd *cmd) { + struct ibv_sge sge; + size_t length = sizeof(ValkeyRdmaCmd); + struct ibv_recv_wr recv_wr, *bad_wr; + int ret; + + sge.addr = (uint64_t)cmd; + sge.length = length; + sge.lkey = ctx->cmd_mr->lkey; + + recv_wr.wr_id = (uint64_t)cmd; + recv_wr.sg_list = &sge; + recv_wr.num_sge = 1; + recv_wr.next = NULL; + + ret = ibv_post_recv(cm_id->qp, &recv_wr, &bad_wr); + if (ret && (ret != EAGAIN)) { + serverLog(LL_WARNING, "RDMA: post recv failed: %d", ret); + return C_ERR; + } + + return C_OK; +} + +/* To make Valkey forkable, buffer which is registered as RDMA + * memory region should be aligned to page size. And the length + * also need be aligned to page size. + * Random segment-fault case like this: + * 0x7f2764ac5000 - 0x7f2764ac7000 + * |ptr0 128| ... |ptr1 4096| ... |ptr2 512| + * + * After ibv_reg_mr(pd, ptr1, 4096, access), the full range of 8K + * becomes DONTFORK. And the child process will hit a segment fault + * during access ptr0/ptr2. + * Note that the memory can be freed by libc free only. 
+ * TODO: move it to zmalloc.c if necessary + */ +static void *page_aligned_zalloc(size_t size) { + void *tmp; + size_t aligned_size, page_size = sysconf(_SC_PAGESIZE); + + aligned_size = (size + page_size - 1) & (~(page_size - 1)); + if (posix_memalign(&tmp, page_size, aligned_size)) { + serverPanic("posix_memalign failed"); + } + + memset(tmp, 0x00, aligned_size); + + return tmp; +} + +static void rdmaDestroyIoBuf(RdmaContext *ctx) { + if (ctx->rx.mr) { + ibv_dereg_mr(ctx->rx.mr); + ctx->rx.mr = NULL; + } + + zlibc_free(ctx->rx.addr); + ctx->rx.addr = NULL; + + if (ctx->tx.mr) { + ibv_dereg_mr(ctx->tx.mr); + ctx->tx.mr = NULL; + } + + zlibc_free(ctx->tx.addr); + ctx->tx.addr = NULL; + + if (ctx->cmd_mr) { + ibv_dereg_mr(ctx->cmd_mr); + ctx->cmd_mr = NULL; + } + + zlibc_free(ctx->cmd_buf); + ctx->cmd_buf = NULL; +} + +static int rdmaSetupIoBuf(RdmaContext *ctx, struct rdma_cm_id *cm_id) { + int access = IBV_ACCESS_LOCAL_WRITE; + size_t length = sizeof(ValkeyRdmaCmd) * VALKEY_RDMA_MAX_WQE * 2; + ValkeyRdmaCmd *cmd; + int i; + + /* setup CMD buf & MR */ + ctx->cmd_buf = page_aligned_zalloc(length); + ctx->cmd_mr = ibv_reg_mr(ctx->pd, ctx->cmd_buf, length, access); + if (!ctx->cmd_mr) { + serverLog(LL_WARNING, "RDMA: reg mr for CMD failed"); + goto destroy_iobuf; + } + + for (i = 0; i < VALKEY_RDMA_MAX_WQE; i++) { + cmd = ctx->cmd_buf + i; + + if (rdmaPostRecv(ctx, cm_id, cmd) == C_ERR) { + serverLog(LL_WARNING, "RDMA: post recv failed"); + goto destroy_iobuf; + } + } + + for (i = VALKEY_RDMA_MAX_WQE; i < VALKEY_RDMA_MAX_WQE * 2; i++) { + cmd = ctx->cmd_buf + i; + cmd->keepalive.opcode = VALKEY_RDMA_INVALID_OPCODE; + } + + /* setup recv buf & MR */ + access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE; + length = valkey_rdma_rx_size; + ctx->rx.addr = page_aligned_zalloc(length); + ctx->rx.length = length; + ctx->rx.mr = ibv_reg_mr(ctx->pd, ctx->rx.addr, length, access); + if (!ctx->rx.mr) { + serverLog(LL_WARNING, "RDMA: reg mr for recv 
buffer failed"); + goto destroy_iobuf; + } + + return C_OK; + +destroy_iobuf: + rdmaDestroyIoBuf(ctx); + return C_ERR; +} + +static int rdmaCreateResource(RdmaContext *ctx, struct rdma_cm_id *cm_id) { + int ret = C_OK; + struct ibv_device_attr device_attr; + struct ibv_qp_init_attr init_attr; + struct ibv_comp_channel *comp_channel = NULL; + struct ibv_cq *cq = NULL; + struct ibv_pd *pd = NULL; + + if (ibv_query_device(cm_id->verbs, &device_attr)) { + serverLog(LL_WARNING, "RDMA: ibv ibv query device failed"); + return C_ERR; + } + + pd = ibv_alloc_pd(cm_id->verbs); + if (!pd) { + serverLog(LL_WARNING, "RDMA: ibv alloc pd failed"); + return C_ERR; + } + + ctx->pd = pd; + + comp_channel = ibv_create_comp_channel(cm_id->verbs); + if (!comp_channel) { + serverLog(LL_WARNING, "RDMA: ibv create comp channel failed"); + return C_ERR; + } + + ctx->comp_channel = comp_channel; + + cq = ibv_create_cq(cm_id->verbs, VALKEY_RDMA_MAX_WQE * 2, NULL, comp_channel, + valkey_rdma_comp_vector % cm_id->verbs->num_comp_vectors); + if (!cq) { + serverLog(LL_WARNING, "RDMA: ibv create cq failed"); + return C_ERR; + } + + ctx->cq = cq; + ibv_req_notify_cq(cq, 0); + + memset(&init_attr, 0, sizeof(init_attr)); + init_attr.cap.max_send_wr = VALKEY_RDMA_MAX_WQE; + init_attr.cap.max_recv_wr = VALKEY_RDMA_MAX_WQE; + init_attr.cap.max_send_sge = device_attr.max_sge; + init_attr.cap.max_recv_sge = 1; + init_attr.qp_type = IBV_QPT_RC; + init_attr.send_cq = cq; + init_attr.recv_cq = cq; + ret = rdma_create_qp(cm_id, pd, &init_attr); + if (ret) { + serverLog(LL_WARNING, "RDMA: create qp failed"); + return C_ERR; + } + + if (rdmaSetupIoBuf(ctx, cm_id)) { + return C_ERR; + } + + return C_OK; +} + +static void rdmaReleaseResource(RdmaContext *ctx) { + rdmaDestroyIoBuf(ctx); + + if (ctx->cq) { + ibv_destroy_cq(ctx->cq); + } + + if (ctx->comp_channel) { + ibv_destroy_comp_channel(ctx->comp_channel); + } + + if (ctx->pd) { + ibv_dealloc_pd(ctx->pd); + } +} + +static int rdmaAdjustSendbuf(RdmaContext 
*ctx, unsigned int length) { + int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE; + + if (length == ctx->tx_length) { + return C_OK; + } + + /* try to free old MR & buffer */ + if (ctx->tx_length) { + ibv_dereg_mr(ctx->tx.mr); + zlibc_free(ctx->tx.addr); + ctx->tx_length = 0; + } + + /* create a new buffer & MR */ + ctx->tx.addr = page_aligned_zalloc(length); + ctx->tx_length = length; + ctx->tx.mr = ibv_reg_mr(ctx->pd, ctx->tx.addr, length, access); + if (!ctx->tx.mr) { + serverRdmaError(server.neterr, "RDMA: reg send mr failed"); + serverLog(LL_WARNING, "RDMA: reg mr for send buffer failed"); + zlibc_free(ctx->tx.addr); + ctx->tx.addr = NULL; + ctx->tx_length = 0; + return C_ERR; + } + + return C_OK; +} + +static int rdmaSendCommand(RdmaContext *ctx, struct rdma_cm_id *cm_id, ValkeyRdmaCmd *cmd) { + struct ibv_send_wr send_wr, *bad_wr; + struct ibv_sge sge; + ValkeyRdmaCmd *_cmd; + int i, ret; + + /* find an unused cmd buffer */ + for (i = VALKEY_RDMA_MAX_WQE; i < 2 * VALKEY_RDMA_MAX_WQE; i++) { + _cmd = ctx->cmd_buf + i; + if (_cmd->keepalive.opcode == VALKEY_RDMA_INVALID_OPCODE) { + break; + } + } + + assert(i < 2 * VALKEY_RDMA_MAX_WQE); + + memcpy(_cmd, cmd, sizeof(ValkeyRdmaCmd)); + sge.addr = (uint64_t)_cmd; + sge.length = sizeof(ValkeyRdmaCmd); + sge.lkey = ctx->cmd_mr->lkey; + + send_wr.sg_list = &sge; + send_wr.num_sge = 1; + send_wr.wr_id = (uint64_t)_cmd; + send_wr.opcode = IBV_WR_SEND; + send_wr.send_flags = IBV_SEND_SIGNALED; + send_wr.next = NULL; + ret = ibv_post_send(cm_id->qp, &send_wr, &bad_wr); + if (ret) { + serverLog(LL_WARNING, "RDMA: post send failed: %d", ret); + return C_ERR; + } + + return C_OK; +} + +static int connRdmaRegisterRx(RdmaContext *ctx, struct rdma_cm_id *cm_id) { + ValkeyRdmaCmd cmd = {0}; + + cmd.memory.opcode = htons(RegisterXferMemory); + cmd.memory.addr = htonu64((uint64_t)ctx->rx.addr); + cmd.memory.length = htonl(ctx->rx.length); + cmd.memory.key = htonl(ctx->rx.mr->rkey); + + 
ctx->rx.offset = 0; + ctx->rx.pos = 0; + + return rdmaSendCommand(ctx, cm_id, &cmd); +} + +static int connRdmaGetFeature(RdmaContext *ctx, struct rdma_cm_id *cm_id, ValkeyRdmaCmd *cmd) { + ValkeyRdmaCmd _cmd; + + _cmd.feature.opcode = htons(GetServerFeature); + _cmd.feature.select = cmd->feature.select; + _cmd.feature.features = htonu64(0); /* currently no feature support */ + + return rdmaSendCommand(ctx, cm_id, &_cmd); +} + +static int connRdmaSetFeature(RdmaContext *ctx, struct rdma_cm_id *cm_id, ValkeyRdmaCmd *cmd) { + UNUSED(ctx); + UNUSED(cm_id); + + /* currently no feature support */ + if (ntohu64(cmd->feature.features)) return C_ERR; + + return C_OK; +} + +static int rdmaHandleEstablished(struct rdma_cm_event *ev) { + struct rdma_cm_id *cm_id = ev->id; + RdmaContext *ctx = cm_id->context; + + connRdmaRegisterRx(ctx, cm_id); + + return C_OK; +} + +static int rdmaHandleDisconnect(struct rdma_cm_event *ev) { + struct rdma_cm_id *cm_id = ev->id; + RdmaContext *ctx = cm_id->context; + connection *conn = ctx->conn; + rdma_connection *rdma_conn = (rdma_connection *)conn; + + conn->state = CONN_STATE_CLOSED; + + /* we can't close connection now, let's mark this connection as closed state */ + listAddNodeTail(pending_list, conn); + rdma_conn->pending_list_node = listLast(pending_list); + + return C_OK; +} + +static int connRdmaHandleRecv(RdmaContext *ctx, struct rdma_cm_id *cm_id, ValkeyRdmaCmd *cmd, uint32_t byte_len) { + if (unlikely(byte_len != sizeof(ValkeyRdmaCmd))) { + serverLog(LL_WARNING, "RDMA: FATAL error, recv corrupted cmd"); + return C_ERR; + } + + switch (ntohs(cmd->keepalive.opcode)) { + case GetServerFeature: connRdmaGetFeature(ctx, cm_id, cmd); break; + + case SetClientFeature: connRdmaSetFeature(ctx, cm_id, cmd); break; + + case Keepalive: break; + + case RegisterXferMemory: + ctx->tx_addr = (char *)ntohu64(cmd->memory.addr); + ctx->tx.length = ntohl(cmd->memory.length); + ctx->tx_key = ntohl(cmd->memory.key); + ctx->tx.offset = 0; + 
rdmaAdjustSendbuf(ctx, ctx->tx.length); + break; + + default: serverLog(LL_WARNING, "RDMA: FATAL error, unknown cmd"); return C_ERR; + } + + return rdmaPostRecv(ctx, cm_id, cmd); +} + +static int connRdmaHandleSend(ValkeyRdmaCmd *cmd) { + /* clear cmd and mark this cmd has already sent */ + memset(cmd, 0x00, sizeof(*cmd)); + cmd->keepalive.opcode = VALKEY_RDMA_INVALID_OPCODE; + + return C_OK; +} + +static int connRdmaHandleRecvImm(RdmaContext *ctx, struct rdma_cm_id *cm_id, ValkeyRdmaCmd *cmd, uint32_t byte_len) { + assert(byte_len + ctx->rx.offset <= ctx->rx.length); + + ctx->rx.offset += byte_len; + + return rdmaPostRecv(ctx, cm_id, cmd); +} + +static int connRdmaHandleWrite(RdmaContext *ctx, uint32_t byte_len) { + UNUSED(ctx); + UNUSED(byte_len); + + return C_OK; +} + + +static int connRdmaHandleCq(rdma_connection *rdma_conn) { + struct rdma_cm_id *cm_id = rdma_conn->cm_id; + RdmaContext *ctx = cm_id->context; + struct ibv_cq *ev_cq = NULL; + void *ev_ctx = NULL; + struct ibv_wc wc = {0}; + ValkeyRdmaCmd *cmd; + int ret; + + if (ibv_get_cq_event(ctx->comp_channel, &ev_cq, &ev_ctx) < 0) { + if (errno != EAGAIN) { + serverLog(LL_WARNING, "RDMA: get CQ event error"); + return C_ERR; + } + } else if (ibv_req_notify_cq(ev_cq, 0)) { + serverLog(LL_WARNING, "RDMA: notify CQ error"); + return C_ERR; + } + +pollcq: + ret = ibv_poll_cq(ctx->cq, 1, &wc); + if (ret < 0) { + serverLog(LL_WARNING, "RDMA: poll recv CQ error"); + return C_ERR; + } else if (ret == 0) { + return C_OK; + } + + ibv_ack_cq_events(ctx->cq, 1); + + if (wc.status != IBV_WC_SUCCESS) { + if (rdma_conn->c.state == CONN_STATE_CONNECTED) { + serverLog(LL_WARNING, "RDMA: CQ handle error status: %s[0x%x], opcode : 0x%x", ibv_wc_status_str(wc.status), + wc.status, wc.opcode); + } + return C_ERR; + } + + switch (wc.opcode) { + case IBV_WC_RECV: + cmd = (ValkeyRdmaCmd *)wc.wr_id; + if (connRdmaHandleRecv(ctx, cm_id, cmd, wc.byte_len) == C_ERR) { + return C_ERR; + } + break; + + case IBV_WC_RECV_RDMA_WITH_IMM: + 
cmd = (ValkeyRdmaCmd *)wc.wr_id; + if (connRdmaHandleRecvImm(ctx, cm_id, cmd, ntohl(wc.imm_data)) == C_ERR) { + rdma_conn->c.state = CONN_STATE_ERROR; + return C_ERR; + } + + break; + case IBV_WC_RDMA_WRITE: + if (connRdmaHandleWrite(ctx, wc.byte_len) == C_ERR) { + return C_ERR; + } + + break; + + case IBV_WC_SEND: + cmd = (ValkeyRdmaCmd *)wc.wr_id; + if (connRdmaHandleSend(cmd) == C_ERR) { + return C_ERR; + } + + break; + + default: serverLog(LL_WARNING, "RDMA: unexpected opcode 0x[%x]", wc.opcode); return C_ERR; + } + + goto pollcq; +} + +static int connRdmaAccept(connection *conn, ConnectionCallbackFunc accept_handler) { + rdma_connection *rdma_conn = (rdma_connection *)conn; + struct rdma_cm_id *cm_id = rdma_conn->cm_id; + RdmaContext *ctx = cm_id->context; + struct ibv_device_attr device_attr; + int ret = C_OK; + + if (conn->state != CONN_STATE_ACCEPTING) return C_ERR; + + conn->state = CONN_STATE_CONNECTED; + + connIncrRefs(conn); + if (!callHandler(conn, accept_handler)) ret = C_ERR; + connDecrRefs(conn); + + if (ibv_query_device(cm_id->verbs, &device_attr)) { + serverLog(LL_WARNING, "RDMA: ibv ibv query device failed"); + return C_ERR; + } + + conn->iovcnt = min(device_attr.max_sge, IOV_MAX); + ctx->conn = conn; /* save conn into RdmaContext */ + + return ret; +} + +static connection *connCreateRdma(void) { + rdma_connection *rdma_conn = zcalloc(sizeof(rdma_connection)); + rdma_conn->c.type = &CT_RDMA; + rdma_conn->c.fd = -1; + rdma_conn->c.iovcnt = 1; /* at least 1, overwrite this on connect */ + + return (connection *)rdma_conn; +} + +static connection *connCreateAcceptedRdma(int fd, void *priv) { + rdma_connection *rdma_conn = (rdma_connection *)connCreateRdma(); + rdma_conn->c.fd = fd; + rdma_conn->c.state = CONN_STATE_ACCEPTING; + rdma_conn->cm_id = priv; + /* The comp channel fd should be always non block */ + connNonBlock(&rdma_conn->c); + + return (connection *)rdma_conn; +} + +static void connRdmaEventHandler(struct aeEventLoop *el, int fd, void 
*clientData, int mask) { + rdma_connection *rdma_conn = (rdma_connection *)clientData; + connection *conn = &rdma_conn->c; + struct rdma_cm_id *cm_id = rdma_conn->cm_id; + RdmaContext *ctx = cm_id->context; + int ret = 0; + + UNUSED(el); + UNUSED(fd); + UNUSED(mask); + + ret = connRdmaHandleCq(rdma_conn); + if (ret == C_ERR) { + conn->state = CONN_STATE_ERROR; + return; + } + + /* uplayer should read all */ + while (ctx->rx.pos < ctx->rx.offset) { + if (conn->read_handler && (callHandler(conn, conn->read_handler) == C_ERR)) { + return; + } + } + + /* recv buf is full, register a new RX buffer */ + if (ctx->rx.pos == ctx->rx.length) { + connRdmaRegisterRx(ctx, cm_id); + } + + /* RDMA comp channel has no POLLOUT event, try to send remaining buffer */ + if ((ctx->tx.offset < ctx->tx.length) && conn->write_handler) { + callHandler(conn, conn->write_handler); + } +} + +static int rdmaHandleConnect(char *err, struct rdma_cm_event *ev, char *ip, size_t ip_len, int *port) { + int ret = C_OK; + struct rdma_cm_id *cm_id = ev->id; + struct sockaddr_storage caddr; + RdmaContext *ctx = NULL; + struct rdma_conn_param conn_param = { + .responder_resources = 1, + .initiator_depth = 1, + .retry_count = 5, + }; + + memcpy(&caddr, &cm_id->route.addr.dst_addr, sizeof(caddr)); + if (caddr.ss_family == AF_INET) { + struct sockaddr_in *s = (struct sockaddr_in *)&caddr; + if (ip) inet_ntop(AF_INET, (void *)&(s->sin_addr), ip, ip_len); + if (port) *port = ntohs(s->sin_port); + } else { + struct sockaddr_in6 *s = (struct sockaddr_in6 *)&caddr; + if (ip) inet_ntop(AF_INET6, (void *)&(s->sin6_addr), ip, ip_len); + if (port) *port = ntohs(s->sin6_port); + } + + ctx = zcalloc(sizeof(RdmaContext)); + ctx->ip = zstrdup(ip); + ctx->port = *port; + cm_id->context = ctx; + if (rdmaCreateResource(ctx, cm_id) == C_ERR) { + goto reject; + } + + ret = rdma_accept(cm_id, &conn_param); + if (ret) { + serverRdmaError(err, "RDMA: accept failed"); + goto free_rdma; + } + + return C_OK; + +free_rdma: + 
rdmaReleaseResource(ctx); +reject: + /* reject connect request if hitting error */ + rdma_reject(cm_id, NULL, 0); + + return C_ERR; +} + +static rdma_listener *rdmaFdToListener(connListener *listener, int fd) { + for (int i = 0; i < listener->count; i++) { + if (listener->fd[i] != fd) continue; + + return (rdma_listener *)listener->priv + i; + } + + return NULL; +} + +/* + * rdmaAccept, actually it works as cm-event handler for listen cm_id. + * accept a connection logic works in two steps: + * 1, handle RDMA_CM_EVENT_CONNECT_REQUEST and return CM fd on success + * 2, handle RDMA_CM_EVENT_ESTABLISHED and return C_OK on success + */ +static int rdmaAccept(connListener *listener, char *err, int fd, char *ip, size_t ip_len, int *port, void **priv) { + struct rdma_cm_event *ev; + enum rdma_cm_event_type ev_type; + int ret = C_OK; + rdma_listener *rdma_listener; + + rdma_listener = rdmaFdToListener(listener, fd); + if (!rdma_listener) { + serverPanic("RDMA: unexpected listen file descriptor"); + } + + ret = rdma_get_cm_event(rdma_listener->cm_channel, &ev); + if (ret) { + if (errno != EAGAIN) { + serverLog(LL_WARNING, "RDMA: listen channel rdma_get_cm_event failed, %s", strerror(errno)); + return ANET_ERR; + } + return ANET_OK; + } + + ev_type = ev->event; + switch (ev_type) { + case RDMA_CM_EVENT_CONNECT_REQUEST: + ret = rdmaHandleConnect(err, ev, ip, ip_len, port); + if (ret == C_OK) { + RdmaContext *ctx = (RdmaContext *)ev->id->context; + *priv = ev->id; + ret = ctx->comp_channel->fd; + } + break; + + case RDMA_CM_EVENT_ESTABLISHED: ret = rdmaHandleEstablished(ev); break; + + case RDMA_CM_EVENT_UNREACHABLE: + case RDMA_CM_EVENT_ADDR_ERROR: + case RDMA_CM_EVENT_ROUTE_ERROR: + case RDMA_CM_EVENT_CONNECT_ERROR: + case RDMA_CM_EVENT_REJECTED: + case RDMA_CM_EVENT_ADDR_CHANGE: + case RDMA_CM_EVENT_DISCONNECTED: + case RDMA_CM_EVENT_TIMEWAIT_EXIT: + rdmaHandleDisconnect(ev); + ret = C_OK; + break; + + case RDMA_CM_EVENT_MULTICAST_JOIN: + case RDMA_CM_EVENT_MULTICAST_ERROR: 
+ case RDMA_CM_EVENT_DEVICE_REMOVAL: + case RDMA_CM_EVENT_ADDR_RESOLVED: + case RDMA_CM_EVENT_ROUTE_RESOLVED: + case RDMA_CM_EVENT_CONNECT_RESPONSE: + default: serverLog(LL_NOTICE, "RDMA: listen channel ignore event: %s", rdma_event_str(ev_type)); break; + } + + if (rdma_ack_cm_event(ev)) { + serverLog(LL_WARNING, "ack cm event failed\n"); + return ANET_ERR; + } + + return ret; +} + +static void connRdmaAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) { + int cport = 0, cfd, max = server.max_new_conns_per_cycle; + struct ClientFlags flags = {0}; + char cip[NET_IP_STR_LEN]; + void *connpriv = NULL; + connListener *listener = (connListener *)privdata; + UNUSED(el); + UNUSED(mask); + + while (max--) { + cfd = rdmaAccept(listener, server.neterr, fd, cip, sizeof(cip), &cport, &connpriv); + if (cfd == ANET_ERR) { + if (errno != EWOULDBLOCK) serverLog(LL_WARNING, "RDMA Accepting client connection: %s", server.neterr); + return; + } else if (cfd == ANET_OK) + continue; + + serverLog(LL_VERBOSE, "RDMA Accepted %s:%d", cip, cport); + acceptCommonHandler(connCreateAcceptedRdma(cfd, connpriv), flags, cip); + } +} + +static int connRdmaSetRwHandler(connection *conn) { + /* IB channel only has POLLIN event */ + if (conn->read_handler || conn->write_handler) { + if (aeCreateFileEvent(server.el, conn->fd, AE_READABLE, conn->type->ae_handler, conn) == AE_ERR) { + return C_ERR; + } + } else { + aeDeleteFileEvent(server.el, conn->fd, AE_READABLE); + } + + return C_OK; +} + +static int connRdmaSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) { + rdma_connection *rdma_conn = (rdma_connection *)conn; + + if (conn->state != CONN_STATE_CONNECTED) { + return C_OK; + } + + conn->write_handler = func; + if (barrier) { + conn->flags |= CONN_FLAG_WRITE_BARRIER; + } else { + conn->flags &= ~CONN_FLAG_WRITE_BARRIER; + } + + /* does this connection has pending write data? 
*/ + if (func && !rdma_conn->pending_list_node) { + listAddNodeTail(pending_list, conn); + rdma_conn->pending_list_node = listLast(pending_list); + } else if (!func && rdma_conn->pending_list_node) { + listDelNode(pending_list, rdma_conn->pending_list_node); + rdma_conn->pending_list_node = NULL; + } + + return connRdmaSetRwHandler(conn); +} + +static int connRdmaSetReadHandler(connection *conn, ConnectionCallbackFunc func) { + conn->read_handler = func; + + return connRdmaSetRwHandler(conn); +} + +static const char *connRdmaGetLastError(connection *conn) { + return strerror(conn->last_errno); +} + +static inline void rdmaConnectFailed(rdma_connection *rdma_conn) { + connection *conn = &rdma_conn->c; + + conn->state = CONN_STATE_ERROR; + conn->last_errno = ENETUNREACH; +} + +static int rdmaConnect(RdmaContext *ctx, struct rdma_cm_id *cm_id) { + struct rdma_conn_param conn_param = {0}; + + if (rdmaCreateResource(ctx, cm_id) == C_ERR) { + return C_ERR; + } + + /* rdma connect with param */ + conn_param.responder_resources = 1; + conn_param.initiator_depth = 1; + conn_param.retry_count = 7; + conn_param.rnr_retry_count = 7; + if (rdma_connect(cm_id, &conn_param)) { + return C_ERR; + } + + anetNonBlock(NULL, ctx->comp_channel->fd); + anetCloexec(ctx->comp_channel->fd); + + return C_OK; +} + +/* TODO: rdmaAccept also deals with RDMA event, server side has different logic with client side, maybe we can merge + * this CM logic in future */ +static void rdmaCMeventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) { + rdma_connection *rdma_conn = (rdma_connection *)clientData; + connection *conn = &rdma_conn->c; + struct rdma_cm_id *cm_id = rdma_conn->cm_id; + RdmaContext *ctx = cm_id->context; + struct rdma_event_channel *cm_channel = ctx->cm_channel; + struct rdma_cm_event *ev; + enum rdma_cm_event_type ev_type; + int ret = C_OK; + + UNUSED(el); + UNUSED(fd); + UNUSED(mask); + + ret = rdma_get_cm_event(cm_channel, &ev); + if (ret) { + if (errno != EAGAIN) { + serverLog(LL_WARNING, "RDMA: client 
channel rdma_get_cm_event failed, %s", strerror(errno)); + } + return; + } + + ev_type = ev->event; + switch (ev_type) { + case RDMA_CM_EVENT_ADDR_RESOLVED: + /* resolve route at most 100ms */ + if (rdma_resolve_route(ev->id, 100)) { + rdmaConnectFailed(rdma_conn); + } + break; + + case RDMA_CM_EVENT_ROUTE_RESOLVED: + if (rdmaConnect(ctx, ev->id) == C_ERR) { + rdmaConnectFailed(rdma_conn); + } + break; + + case RDMA_CM_EVENT_ESTABLISHED: + rdmaHandleEstablished(ev); + conn->state = CONN_STATE_CONNECTED; + conn->fd = ctx->comp_channel->fd; + if (conn->conn_handler) { + callHandler(conn, conn->conn_handler); + } + break; + + case RDMA_CM_EVENT_UNREACHABLE: + case RDMA_CM_EVENT_ROUTE_ERROR: + case RDMA_CM_EVENT_ADDR_ERROR: + case RDMA_CM_EVENT_REJECTED: rdmaConnectFailed(rdma_conn); break; + + case RDMA_CM_EVENT_CONNECT_ERROR: + case RDMA_CM_EVENT_TIMEWAIT_EXIT: + case RDMA_CM_EVENT_CONNECT_REQUEST: + case RDMA_CM_EVENT_ADDR_CHANGE: + case RDMA_CM_EVENT_DISCONNECTED: rdmaHandleDisconnect(ev); break; + + case RDMA_CM_EVENT_MULTICAST_JOIN: + case RDMA_CM_EVENT_MULTICAST_ERROR: + case RDMA_CM_EVENT_DEVICE_REMOVAL: + case RDMA_CM_EVENT_CONNECT_RESPONSE: + default: serverLog(LL_NOTICE, "RDMA: client channel ignore event: %s", rdma_event_str(ev_type)); + } + + if (rdma_ack_cm_event(ev)) { + serverLog(LL_NOTICE, "RDMA: ack cm event failed\n"); + } + + /* connection error or closed by remote peer */ + if (conn->state == CONN_STATE_ERROR) { + callHandler(conn, conn->conn_handler); + } +} + +/* free resource during connection close */ +static int rdmaResolveAddr(rdma_connection *rdma_conn, const char *addr, int port, const char *src_addr) { + struct addrinfo hints, *servinfo = NULL, *p = NULL; + struct rdma_event_channel *cm_channel = NULL; + struct rdma_cm_id *cm_id = NULL; + RdmaContext *ctx = NULL; + struct sockaddr_storage saddr; + char _port[6]; /* strlen("65535") */ + int availableAddrs = 0; + int ret = C_ERR; + + UNUSED(src_addr); + ctx = zcalloc(sizeof(RdmaContext)); + 
if (!ctx) { + serverLog(LL_WARNING, "RDMA: Out of memory"); + goto out; + } + + cm_channel = rdma_create_event_channel(); + if (!cm_channel) { + serverLog(LL_WARNING, "RDMA: create event channel failed"); + goto out; + } + ctx->cm_channel = cm_channel; + + if (rdma_create_id(cm_channel, &cm_id, (void *)ctx, RDMA_PS_TCP)) { + serverLog(LL_WARNING, "RDMA: create id failed"); + goto out; + } + rdma_conn->cm_id = cm_id; + + if (anetNonBlock(NULL, cm_channel->fd) != C_OK) { + serverLog(LL_WARNING, "RDMA: set cm channel fd non-block failed"); + goto out; + } + + snprintf(_port, 6, "%d", port); + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + + if (getaddrinfo(addr, _port, &hints, &servinfo)) { + hints.ai_family = AF_INET6; + if (getaddrinfo(addr, _port, &hints, &servinfo)) { + serverLog(LL_WARNING, "RDMA: bad server addr info"); + goto out; + } + } + + for (p = servinfo; p != NULL; p = p->ai_next) { + if (p->ai_family == PF_INET) { + memcpy(&saddr, p->ai_addr, sizeof(struct sockaddr_in)); + ((struct sockaddr_in *)&saddr)->sin_port = htons(port); + } else if (p->ai_family == PF_INET6) { + memcpy(&saddr, p->ai_addr, sizeof(struct sockaddr_in6)); + ((struct sockaddr_in6 *)&saddr)->sin6_port = htons(port); + } else { + serverLog(LL_WARNING, "RDMA: Unsupported family"); + goto out; + } + + /* resolve addr at most 100ms */ + if (rdma_resolve_addr(cm_id, NULL, (struct sockaddr *)&saddr, 100)) { + continue; + } + availableAddrs++; + } + + if (!availableAddrs) { + serverLog(LL_WARNING, "RDMA: server addr not available"); + goto out; + } + + ret = C_OK; + +out: + if (servinfo) { + freeaddrinfo(servinfo); + } + + return ret; +} + +static int connRdmaWait(connection *conn, long start, long timeout) { + rdma_connection *rdma_conn = (rdma_connection *)conn; + long long remaining = timeout, wait, elapsed = 0; + + remaining = timeout - elapsed; + wait = (remaining < VALKEY_RDMA_SYNCIO_RES) ? 
remaining : VALKEY_RDMA_SYNCIO_RES;
+    aeWait(conn->fd, AE_READABLE, wait);
+    elapsed = mstime() - start;
+    if (elapsed >= timeout) {
+        errno = ETIMEDOUT;
+        return C_ERR;
+    }
+
+    if (connRdmaHandleCq(rdma_conn) == C_ERR) {
+        conn->state = CONN_STATE_ERROR;
+        return C_ERR;
+    }
+
+    return C_OK;
+}
+/* Start a connect to 'addr':'port' that completes through the event loop.
+ * Resolves the address with rdmaResolveAddr(), registers rdmaCMeventHandler
+ * as the AE_READABLE handler of the CM channel fd, and stores
+ * 'connect_handler' in conn->conn_handler. Returns C_OK or C_ERR. */
+static int connRdmaConnect(connection *conn,
+                           const char *addr,
+                           int port,
+                           const char *src_addr,
+                           ConnectionCallbackFunc connect_handler) {
+    rdma_connection *rdma_conn = (rdma_connection *)conn;
+    struct rdma_cm_id *cm_id;
+    RdmaContext *ctx;
+
+    if (rdmaResolveAddr(rdma_conn, addr, port, src_addr) == C_ERR) {
+        return C_ERR;
+    }
+
+    cm_id = rdma_conn->cm_id;
+    ctx = cm_id->context;
+    if (aeCreateFileEvent(server.el, ctx->cm_channel->fd, AE_READABLE, rdmaCMeventHandler, conn) == AE_ERR) {
+        return C_ERR;
+    }
+
+    conn->conn_handler = connect_handler;
+
+    return C_OK;
+}
+/* Connect to 'addr':'port' and wait until conn->state becomes
+ * CONN_STATE_CONNECTED. Performs the same setup as connRdmaConnect() (without
+ * a source address or connect handler), then loops on connRdmaWait(), which
+ * returns C_ERR with errno ETIMEDOUT once 'timeout' has elapsed since the
+ * call started. Returns C_OK or C_ERR. */
+static int connRdmaBlockingConnect(connection *conn, const char *addr, int port, long long timeout) {
+    rdma_connection *rdma_conn = (rdma_connection *)conn;
+    struct rdma_cm_id *cm_id;
+    RdmaContext *ctx;
+    long long start = mstime();
+
+    if (rdmaResolveAddr(rdma_conn, addr, port, NULL) == C_ERR) {
+        return C_ERR;
+    }
+
+    cm_id = rdma_conn->cm_id;
+    ctx = cm_id->context;
+    if (aeCreateFileEvent(server.el, ctx->cm_channel->fd, AE_READABLE, rdmaCMeventHandler, conn) == AE_ERR) {
+        return C_ERR;
+    }
+    /* NOTE(review): this loop only calls connRdmaWait(); nothing visible here dispatches the CM events
+     * handled by rdmaCMeventHandler, which is what sets CONN_STATE_CONNECTED - confirm how it progresses. */
+    do {
+        if (connRdmaWait(conn, start, timeout) == C_ERR) {
+            return C_ERR;
+        }
+    } while (conn->state != CONN_STATE_CONNECTED);
+
+    return C_OK;
+}
+/* Shutdown hook of the connection type: intentionally does nothing. */
+static void connRdmaShutdown(connection *conn) {
+    UNUSED(conn);
+}
+/* Close the connection and free it. Removes the read event of conn->fd,
+ * disconnects the CM id, polls the completion queue one last time, releases
+ * the context resources, destroys the queue pair, the CM id and the CM event
+ * channel, and finally frees the context and the connection itself. */
+static void connRdmaClose(connection *conn) {
+    rdma_connection *rdma_conn = (rdma_connection *)conn;
+    struct rdma_cm_id *cm_id = rdma_conn->cm_id;
+    RdmaContext *ctx;
+
+    if (conn->fd != -1) {
+        aeDeleteFileEvent(server.el, conn->fd, AE_READABLE);
+        conn->fd = -1;
+    }
+
+    if (!cm_id) { /* NOTE(review): this early return skips zfree(conn), unlike the normal path below;
+                   * confirm whether a connection that never got a cm_id is released elsewhere */
+        return;
+    }
+
+    ctx = cm_id->context;
+    rdma_disconnect(cm_id);
+
+    /* poll all CQ before close */
+    connRdmaHandleCq(rdma_conn);
+    rdmaReleaseResource(ctx);
+    if (cm_id->qp) {
+        ibv_destroy_qp(cm_id->qp);
+    }
+
+    rdma_destroy_id(cm_id);
+    if (ctx->cm_channel) {
+        aeDeleteFileEvent(server.el, ctx->cm_channel->fd, AE_READABLE);
+        rdma_destroy_event_channel(ctx->cm_channel);
+    }
+
+    rdma_conn->cm_id = NULL;
+    zfree(ctx);
+    zfree(conn);
+}
+/* Copy 'data_len' bytes into the local TX buffer at the current TX offset and
+ * post one RDMA_WRITE_WITH_IMM of that range to the same offset of the peer
+ * buffer (ctx->tx_addr / ctx->tx_key); the immediate value carries 'data_len'
+ * in network byte order. Does not check that the data fits: the callers in
+ * this file clamp the length to the free TX space first.
+ * Returns 'data_len'; on post failure marks the connection as
+ * CONN_STATE_ERROR and returns C_ERR converted to size_t. */
+static size_t connRdmaSend(connection *conn, const void *data, size_t data_len) {
+    rdma_connection *rdma_conn = (rdma_connection *)conn;
+    struct rdma_cm_id *cm_id = rdma_conn->cm_id;
+    RdmaContext *ctx = cm_id->context;
+    struct ibv_send_wr send_wr, *bad_wr;
+    struct ibv_sge sge;
+    uint32_t off = ctx->tx.offset;
+    char *addr = ctx->tx.addr + off;
+    char *remote_addr = ctx->tx_addr + ctx->tx.offset;
+    int ret;
+
+    memcpy(addr, data, data_len);
+
+    sge.addr = (uint64_t)addr;
+    sge.lkey = ctx->tx.mr->lkey;
+    sge.length = data_len;
+
+    send_wr.sg_list = &sge;
+    send_wr.num_sge = 1;
+    send_wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+    /* request a signaled completion only once every VALKEY_RDMA_MAX_WQE / 2 posts */
+    send_wr.send_flags = (++ctx->tx_ops % (VALKEY_RDMA_MAX_WQE / 2)) ? 0 : IBV_SEND_SIGNALED;
+    send_wr.imm_data = htonl(data_len);
+    send_wr.wr.rdma.remote_addr = (uint64_t)remote_addr;
+    send_wr.wr.rdma.rkey = ctx->tx_key;
+    send_wr.wr_id = 0;
+    send_wr.next = NULL;
+    ret = ibv_post_send(cm_id->qp, &send_wr, &bad_wr);
+    if (ret) {
+        serverLog(LL_WARNING, "RDMA: post send failed: %d", ret);
+        conn->state = CONN_STATE_ERROR;
+        return C_ERR;
+    }
+
+    ctx->tx.offset += data_len;
+
+    return data_len;
+}
+/* Write up to 'data_len' bytes, clamped to the free space left in the TX
+ * buffer. Returns the number of bytes handed to connRdmaSend(), 0 when the TX
+ * buffer is full, or C_ERR when the connection is in error/closed state or
+ * the send fails. */
+static int connRdmaWrite(connection *conn, const void *data, size_t data_len) {
+    rdma_connection *rdma_conn = (rdma_connection *)conn;
+    struct rdma_cm_id *cm_id = rdma_conn->cm_id;
+    RdmaContext *ctx = cm_id->context;
+    uint32_t towrite;
+
+    if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
+        return C_ERR;
+    }
+
+    assert(ctx->tx.offset <= ctx->tx.length);
+    towrite = MIN(ctx->tx.length - ctx->tx.offset, data_len);
+    if (!towrite) {
+        return 0;
+    }
+
+    return connRdmaSend(conn, data, towrite);
+}
+/* Write every element of 'iov' in order through connRdmaWrite(). Returns the
+ * total number of bytes written (which can be less than requested when the TX
+ * buffer fills up), or C_ERR as soon as one write fails. */
+static int connRdmaWritev(connection *conn, const struct iovec *iov, int iovcnt) {
+    int ret, nwritten = 0;
+
+    for (int i = 0; i < iovcnt; i++) {
+        ret = connRdmaWrite(conn, iov[i].iov_base, iov[i].iov_len);
+        if (ret == C_ERR) return C_ERR;
+        nwritten += ret;
+    }
+
+    return nwritten;
+}
+/* Copy at most 'buf_len' unread bytes (those between rx.pos and rx.offset)
+ * from the RX buffer into 'buf' and advance rx.pos. Returns the number of
+ * bytes copied, which may be 0. */
+static inline uint32_t rdmaRead(RdmaContext *ctx, void *buf, size_t buf_len) {
+    uint32_t toread;
+
+    toread = MIN(ctx->rx.offset - ctx->rx.pos, buf_len);
+
+    assert(ctx->rx.pos + toread <= ctx->rx.length);
+    memcpy(buf, ctx->rx.addr + ctx->rx.pos, toread);
+
+    ctx->rx.pos += toread;
+
+    return toread;
+}
+/* Read handler of the connection type: returns C_ERR when the connection is
+ * in error/closed state and -1 when the RX buffer holds no unread data. */
+static int connRdmaRead(connection *conn, void *buf, size_t buf_len) {
+    rdma_connection *rdma_conn = (rdma_connection *)conn;
+    struct rdma_cm_id *cm_id = rdma_conn->cm_id;
+    RdmaContext *ctx = cm_id->context;
+
+    if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
+        return C_ERR;
+    }
+
+    /* No more data to read */
+    if (ctx->rx.pos == ctx->rx.offset) {
+        return -1;
+    }
+
+    assert(ctx->rx.pos <
ctx->rx.offset); + + return rdmaRead(ctx, buf, buf_len); +} + +static ssize_t connRdmaSyncWrite(connection *conn, char *ptr, ssize_t size, long long timeout) { + rdma_connection *rdma_conn = (rdma_connection *)conn; + struct rdma_cm_id *cm_id = rdma_conn->cm_id; + RdmaContext *ctx = cm_id->context; + ssize_t nwritten = 0; + long long start = mstime(); + uint32_t towrite; + + if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) { + return C_ERR; + } + + assert(ctx->tx.offset <= ctx->tx.length); + if (ctx->tx.offset < ctx->tx.length) { + /* TX buffer is available */ + goto copy; + } + +wait: + if (connRdmaWait(conn, start, timeout) == C_ERR) { + return C_ERR; + } + + if (unlikely(!ctx->tx.mr)) { + goto wait; + } + +copy: + towrite = MIN(ctx->tx.length - ctx->tx.offset, size - nwritten); + if (connRdmaSend(conn, ptr, towrite) == (size_t)C_ERR) { + return C_ERR; + } else { + ptr += towrite; + nwritten += towrite; + } + + if (nwritten < size) { + goto wait; + } + + return size; +} + +static ssize_t connRdmaSyncRead(connection *conn, char *ptr, ssize_t size, long long timeout) { + rdma_connection *rdma_conn = (rdma_connection *)conn; + struct rdma_cm_id *cm_id = rdma_conn->cm_id; + RdmaContext *ctx = cm_id->context; + ssize_t nread = 0; + long long start = mstime(); + uint32_t toread; + + if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) { + return C_ERR; + } + + assert(ctx->rx.pos <= ctx->rx.offset); + if (ctx->rx.pos < ctx->rx.offset) { + goto copy; + } + +wait: + if (connRdmaWait(conn, start, timeout) == C_ERR) { + return C_ERR; + } + +copy: + toread = rdmaRead(ctx, ptr, size - nread); + ptr += toread; + nread += toread; + if (nread < size) { + goto wait; + } + + return size; +} + +static ssize_t connRdmaSyncReadLine(connection *conn, char *ptr, ssize_t size, long long timeout) { + rdma_connection *rdma_conn = (rdma_connection *)conn; + struct rdma_cm_id *cm_id = rdma_conn->cm_id; + RdmaContext *ctx = cm_id->context; + 
ssize_t nread = 0; + long long start = mstime(); + uint32_t toread; + char *c; + char nl = 0; + + if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) { + return C_ERR; + } + + assert(ctx->rx.pos <= ctx->rx.offset); + if (ctx->rx.pos < ctx->rx.offset) { + goto copy; + } + +wait: + if (connRdmaWait(conn, start, timeout) == C_ERR) { + return C_ERR; + } + +copy: + for (toread = 0; toread <= ctx->rx.offset - ctx->rx.pos; toread++) { + c = ctx->rx.addr + ctx->rx.pos + toread; + if (*c == '\n') { + *c = '\0'; + if (toread && *(c - 1) == '\r') { + *(c - 1) = '\0'; + } + nl = 1; + break; + } + } + + toread = rdmaRead(ctx, ptr, MIN(toread + nl, size - nread)); + ptr += toread; + nread += toread; + if (nl) { + return nread; + } + + if (nread < size) { + goto wait; + } + + return size; +} + +static const char *connRdmaGetType(connection *conn) { + UNUSED(conn); + + return CONN_TYPE_RDMA; +} + +static int rdmaServer(char *err, int port, char *bindaddr, int af, rdma_listener *rdma_listener) { + int ret = ANET_OK, rv, afonly = 1; + char _port[6]; /* strlen("65535") */ + struct addrinfo hints, *servinfo, *p; + struct sockaddr_storage sock_addr; + struct rdma_cm_id *listen_cmid = NULL; + struct rdma_event_channel *listen_channel = NULL; + + snprintf(_port, 6, "%d", port); + memset(&hints, 0, sizeof(hints)); + hints.ai_family = af; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; /* No effect if bindaddr != NULL */ + if (bindaddr && !strcmp("*", bindaddr)) bindaddr = NULL; + + if (af == AF_INET6 && bindaddr && !strcmp("::*", bindaddr)) bindaddr = NULL; + + if ((rv = getaddrinfo(bindaddr, _port, &hints, &servinfo)) != 0) { + serverRdmaError(err, "RDMA: %s", gai_strerror(rv)); + return ANET_ERR; + } else if (!servinfo) { + serverRdmaError(err, "RDMA: get addr info failed"); + ret = ANET_ERR; + goto end; + } + + listen_channel = rdma_create_event_channel(); + if (!listen_channel) { + serverLog(LL_WARNING, "RDMA: create event channel failed"); + 
goto error; + } + + for (p = servinfo; p != NULL; p = p->ai_next) { + memset(&sock_addr, 0, sizeof(sock_addr)); + if (p->ai_family == AF_INET6) { + memcpy(&sock_addr, p->ai_addr, sizeof(struct sockaddr_in6)); + ((struct sockaddr_in6 *)&sock_addr)->sin6_family = AF_INET6; + ((struct sockaddr_in6 *)&sock_addr)->sin6_port = htons(port); + } else { + memcpy(&sock_addr, p->ai_addr, sizeof(struct sockaddr_in)); + ((struct sockaddr_in *)&sock_addr)->sin_family = AF_INET; + ((struct sockaddr_in *)&sock_addr)->sin_port = htons(port); + } + + if (rdma_create_id(listen_channel, &listen_cmid, NULL, RDMA_PS_TCP)) { + serverRdmaError(err, "RDMA: create listen cm id error"); + return ANET_ERR; + } + + rdma_set_option(listen_cmid, RDMA_OPTION_ID, RDMA_OPTION_ID_AFONLY, &afonly, sizeof(afonly)); + + if (rdma_bind_addr(listen_cmid, (struct sockaddr *)&sock_addr)) { + serverRdmaError(err, "RDMA: bind addr error"); + goto error; + } + + if (rdma_listen(listen_cmid, 0)) { + serverRdmaError(err, "RDMA: listen addr error"); + goto error; + } + + rdma_listener->cm_id = listen_cmid; + rdma_listener->cm_channel = listen_channel; + goto end; + } + +error: + if (listen_cmid) rdma_destroy_id(listen_cmid); + if (listen_channel) rdma_destroy_event_channel(listen_channel); + ret = ANET_ERR; + +end: + freeaddrinfo(servinfo); + return ret; +} + +int connRdmaListen(connListener *listener) { + int j, ret; + char **bindaddr = listener->bindaddr; + int bindaddr_count = listener->bindaddr_count; + int port = listener->port; + char *default_bindaddr[2] = {"*", "-::*"}; + rdma_listener *rdma_listener; + + assert(server.proto_max_bulk_len <= 512ll * 1024 * 1024); + + /* Force binding of 0.0.0.0 if no bind address is specified. 
*/
+    if (listener->bindaddr_count == 0) {
+        bindaddr_count = 2;
+        bindaddr = default_bindaddr;
+    }
+
+    listener->priv = rdma_listener = zcalloc_num(bindaddr_count, sizeof(*rdma_listener));
+    for (j = 0; j < bindaddr_count; j++) {
+        char *addr = bindaddr[j];
+        int optional = *addr == '-'; /* a leading '-' is stripped; 'optional' is not used otherwise */
+
+        if (optional) addr++;
+        if (strchr(addr, ':')) {
+            /* Bind IPv6 address. */
+            ret = rdmaServer(server.neterr, port, addr, AF_INET6, rdma_listener);
+        } else {
+            /* Bind IPv4 address. */
+            ret = rdmaServer(server.neterr, port, addr, AF_INET, rdma_listener);
+        }
+
+        if (ret == ANET_ERR) {
+            serverLog(LL_WARNING, "RDMA: Could not create server for %s:%d: %s", addr, port, server.neterr);
+
+            return C_ERR;
+        }
+
+        int fd = rdma_listener->cm_channel->fd;
+        anetNonBlock(NULL, fd);
+        anetCloexec(fd);
+        listener->fd[listener->count++] = fd;
+        rdma_listener++;
+    }
+
+    return C_OK;
+}
+/* Report the peer address ('remote' != 0) or the local address of the
+ * connection. 'ip' (capacity 'ip_len') and 'port' are optional outputs and
+ * may be NULL. Only AF_INET and AF_INET6 are handled.
+ * Returns 0 on success. On failure returns -1 after setting 'ip' to "?" (or
+ * to an empty string when ip_len == 1) and '*port' to 0. */
+static int connRdmaAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) {
+    rdma_connection *rdma_conn = (rdma_connection *)conn;
+    struct rdma_cm_id *cm_id = rdma_conn->cm_id;
+    struct sockaddr_storage *ss = NULL;
+    struct sockaddr_in *sa4;
+    struct sockaddr_in6 *sa6;
+
+    if (remote)
+        ss = (struct sockaddr_storage *)rdma_get_peer_addr(cm_id);
+    else
+        ss = (struct sockaddr_storage *)rdma_get_local_addr(cm_id);
+
+    if (!ss) {
+        goto error;
+    }
+
+    if (ss->ss_family == AF_INET) {
+        sa4 = (struct sockaddr_in *)ss;
+        if (ip) {
+            if (inet_ntop(AF_INET, (void *)&(sa4->sin_addr), ip, ip_len) == NULL) {
+                goto error;
+            }
+        }
+
+        if (port) {
+            *port = ntohs(sa4->sin_port);
+        }
+    } else if (ss->ss_family == AF_INET6) {
+        sa6 = (struct sockaddr_in6 *)ss;
+        if (ip) {
+            if (inet_ntop(AF_INET6, (void *)&(sa6->sin6_addr), ip, ip_len) == NULL) {
+                goto error;
+            }
+        }
+
+        if (port) {
+            *port = ntohs(sa6->sin6_port);
+        }
+    } else {
+        /* TODO IB protocol */
+        goto error;
+    }
+
+    return 0;
+
+error:
+    if (ip) {
+        if (ip_len >= 2) {
+            ip[0] = '?';
+            ip[1] = '\0';
+        } else if (ip_len == 1) {
+            ip[0] = '\0';
+        }
+    }
+
+    if (port) *port = 0;
+
+    return -1;
+}
+/* One-time initialization of the connection type: creates the pending
+ * connection list, checks at build time that every wire command is 32 bytes,
+ * and calls ibv_fork_init(). A failure of ibv_fork_init() is only logged. */
+static void rdmaInit(void) {
+    pending_list = listCreate();
+
+    VALKEY_BUILD_BUG_ON(sizeof(ValkeyRdmaFeature) != 32);
+    VALKEY_BUILD_BUG_ON(sizeof(ValkeyRdmaKeepalive) != 32);
+    VALKEY_BUILD_BUG_ON(sizeof(ValkeyRdmaMemory) != 32);
+    VALKEY_BUILD_BUG_ON(sizeof(ValkeyRdmaCmd) != 32);
+
+    if (ibv_fork_init()) {
+        serverLog(LL_WARNING, "RDMA: FATAL error, ibv_fork_init failed");
+    }
+}
+/* Return non-zero when at least one connection sits in the pending list
+ * (0 when the list has not been created yet). */
+static int rdmaHasPendingData(void) {
+    if (!pending_list) return 0;
+
+    return listLength(pending_list) > 0;
+}
+/* Walk the pending list once. Connections in error/closed state are removed
+ * from the list and their read handler (or else write handler) is invoked
+ * directly; for every other connection connRdmaEventHandler() is called.
+ * Returns the length the list had when the walk started. */
+static int rdmaProcessPendingData(void) {
+    listIter li;
+    listNode *ln;
+    rdma_connection *rdma_conn;
+    connection *conn;
+    listNode *node;
+    int processed;
+
+    processed = listLength(pending_list);
+    listRewind(pending_list, &li);
+    while ((ln = listNext(&li))) {
+        rdma_conn = listNodeValue(ln);
+        conn = &rdma_conn->c;
+        node = rdma_conn->pending_list_node;
+
+        /* a connection can be disconnected by remote peer, CM event mark state as CONN_STATE_CLOSED, kick connection
+         * read/write handler to close connection */
+        if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
+            listDelNode(pending_list, node);
+            /* do NOT call callHandler(conn, conn->read_handler) here, conn is freed in handler! */
+            if (conn->read_handler) {
+                conn->read_handler(conn);
+            } else if (conn->write_handler) {
+                conn->write_handler(conn);
+            }
+
+            continue;
+        }
+
+        connRdmaEventHandler(NULL, -1, rdma_conn, 0);
+    }
+
+    return processed;
+}
+/* Method table that plugs the RDMA transport into the generic connection
+ * layer; it is registered with connTypeRegister(). */
+static ConnectionType CT_RDMA = {
+    /* connection type */
+    .get_type = connRdmaGetType,
+
+    /* connection type initialize & finalize & configure */
+    .init = rdmaInit,
+    .cleanup = NULL,
+
+    /* ae & accept & listen & error & address handler */
+    .ae_handler = connRdmaEventHandler,
+    .accept_handler = connRdmaAcceptHandler,
+    //.cluster_accept_handler = NULL,
+    .listen = connRdmaListen,
+    .addr = connRdmaAddr,
+
+    /* create/close connection */
+    .conn_create = connCreateRdma,
+    .conn_create_accepted = connCreateAcceptedRdma,
+    .shutdown = connRdmaShutdown,
+    .close = connRdmaClose,
+
+    /* connect & accept */
+    .connect = connRdmaConnect,
+    .blocking_connect = connRdmaBlockingConnect,
+    .accept = connRdmaAccept,
+
+    /* IO */
+    .write = connRdmaWrite,
+    .writev = connRdmaWritev,
+    .read = connRdmaRead,
+    .set_write_handler = connRdmaSetWriteHandler,
+    .set_read_handler = connRdmaSetReadHandler,
+    .get_last_error = connRdmaGetLastError,
+    .sync_write = connRdmaSyncWrite,
+    .sync_read = connRdmaSyncRead,
+    .sync_readline = connRdmaSyncReadLine,
+
+    /* pending data */
+    .has_pending_data = rdmaHasPendingData,
+    .process_pending_data = rdmaProcessPendingData,
+};
+/* Return the listener registered for CONN_TYPE_RDMA. The lookup is done once
+ * and cached in a function-local static; a missing listener is a fatal
+ * assertion. */
+static struct connListener *rdmaListener(void) {
+    static struct connListener *listener = NULL;
+
+    if (listener) return listener;
+
+    listener = listenerByType(CONN_TYPE_RDMA);
+    serverAssert(listener != NULL);
+
+    return listener;
+}
+/* Return the connection type registered as CONN_TYPE_RDMA, cached after the
+ * first lookup; a missing type is a fatal assertion. */
+ConnectionType *connectionTypeRdma(void) {
+    static ConnectionType *ct_rdma = NULL;
+
+    if (ct_rdma != NULL) return ct_rdma;
+
+    ct_rdma = connectionByType(CONN_TYPE_RDMA);
+    serverAssert(ct_rdma != NULL);
+
+    return ct_rdma;
+}
+
+/* rdma listener has different create/close logic from TCP, we can't re-use 'int changeListener(connListener *listener)'
+ * directly */ +static int rdmaChangeListener(void) { + struct connListener *listener = rdmaListener(); + + /* Close old servers */ + for (int i = 0; i < listener->count; i++) { + if (listener->fd[i] == -1) continue; + + aeDeleteFileEvent(server.el, listener->fd[i], AE_READABLE); + listener->fd[i] = -1; + struct rdma_listener *rdma_listener = (struct rdma_listener *)listener->priv + i; + rdma_destroy_id(rdma_listener->cm_id); + rdma_destroy_event_channel(rdma_listener->cm_channel); + } + + listener->count = 0; + zfree(listener->priv); + + closeListener(listener); + + /* Just close the server if port disabled */ + if (listener->port == 0) { + if (server.set_proc_title) serverSetProcTitle(NULL); + return VALKEYMODULE_OK; + } + + /* Re-create listener */ + if (connListen(listener) != C_OK) { + return VALKEYMODULE_ERR; + } + + /* Create event handlers */ + if (createSocketAcceptHandler(listener, listener->ct->accept_handler) != C_OK) { + serverPanic("Unrecoverable error creating %s accept handler.", listener->ct->get_type(NULL)); + } + + if (server.set_proc_title) serverSetProcTitle(NULL); + + return VALKEYMODULE_OK; +} + +#ifdef BUILD_RDMA_MODULE + +#include "release.h" + +static long long rdmaGetPort(const char *name, void *privdata) { + UNUSED(name); + UNUSED(privdata); + struct connListener *listener = rdmaListener(); + + return listener->port; +} + +static int rdmaSetPort(const char *name, long long val, void *privdata, ValkeyModuleString **err) { + UNUSED(name); + UNUSED(privdata); + UNUSED(err); + struct connListener *listener = rdmaListener(); + listener->port = val; + + return VALKEYMODULE_OK; +} + +static ValkeyModuleString *rdma_bind; + +static void rdmaBuildBind(void *ctx) { + struct connListener *listener = rdmaListener(); + + if (rdma_bind) ValkeyModule_FreeString(NULL, rdma_bind); + + sds rdma_bind_str = sdsjoin(listener->bindaddr, listener->bindaddr_count, " "); + rdma_bind = ValkeyModule_CreateString(ctx, rdma_bind_str, sdslen(rdma_bind_str)); +} + 
+static ValkeyModuleString *rdmaGetBind(const char *name, void *privdata) { + UNUSED(name); + UNUSED(privdata); + + return rdma_bind; +} + +static int rdmaSetBind(const char *name, ValkeyModuleString *val, void *privdata, ValkeyModuleString **err) { + UNUSED(name); + UNUSED(err); + struct connListener *listener = rdmaListener(); + const char *bind = ValkeyModule_StringPtrLen(val, NULL); + int nexts; + sds *exts = sdssplitlen(bind, strlen(bind), " ", 1, &nexts); + + if (nexts > CONFIG_BINDADDR_MAX) { + serverLog(LL_WARNING, "RDMA: Unsupported bind ( > %d)", CONFIG_BINDADDR_MAX); + return VALKEYMODULE_ERR; + } + + /* Free old bind addresses */ + for (int j = 0; j < listener->bindaddr_count; j++) { + zfree(listener->bindaddr[j]); + } + + for (int j = 0; j < nexts; j++) listener->bindaddr[j] = zstrdup(exts[j]); + listener->bindaddr_count = nexts; + + sdsfreesplitres(exts, nexts); + rdmaBuildBind(privdata); + + return VALKEYMODULE_OK; +} + +static int rdmaApplyListener(ValkeyModuleCtx *ctx, void *privdata, ValkeyModuleString **err) { + UNUSED(ctx); + UNUSED(privdata); + UNUSED(err); + + return rdmaChangeListener(); +} + +static void rdmaListenerAddConfig(void *ctx) { + serverAssert(ValkeyModule_RegisterNumericConfig(ctx, "port", 0, VALKEYMODULE_CONFIG_DEFAULT, 0, 65535, rdmaGetPort, + rdmaSetPort, rdmaApplyListener, NULL) == VALKEYMODULE_OK); + serverAssert(ValkeyModule_RegisterStringConfig(ctx, "bind", "", VALKEYMODULE_CONFIG_DEFAULT, rdmaGetBind, + rdmaSetBind, rdmaApplyListener, ctx) == VALKEYMODULE_OK); + serverAssert(ValkeyModule_LoadConfigs(ctx) == VALKEYMODULE_OK); +} + +int ValkeyModule_OnLoad(void *ctx, ValkeyModuleString **argv, int argc) { + /* Connection modules MUST be part of the same build as valkey. 
*/ + if (strcmp(REDIS_BUILD_ID_RAW, serverBuildIdRaw())) { + serverLog(LL_NOTICE, "Connection type %s was not built together with the valkey-server used.", CONN_TYPE_RDMA); + return VALKEYMODULE_ERR; + } + + if (ValkeyModule_Init(ctx, CONN_TYPE_RDMA, 1, VALKEYMODULE_APIVER_1) == VALKEYMODULE_ERR) return VALKEYMODULE_ERR; + + /* Connection modules is available only bootup. */ + if ((ValkeyModule_GetContextFlags(ctx) & VALKEYMODULE_CTX_FLAGS_SERVER_STARTUP) == 0) { + serverLog(LL_NOTICE, "Connection type %s can be loaded only during bootup", CONN_TYPE_RDMA); + return VALKEYMODULE_ERR; + } + + ValkeyModule_SetModuleOptions(ctx, VALKEYMODULE_OPTIONS_HANDLE_REPL_ASYNC_LOAD); + + if (connTypeRegister(&CT_RDMA) != C_OK) return VALKEYMODULE_ERR; + + rdmaListenerAddConfig(ctx); + + struct connListener *listener = rdmaListener(); + listener->ct = connectionTypeRdma(); + listener->bindaddr = zcalloc_num(CONFIG_BINDADDR_MAX, sizeof(listener->bindaddr[0])); + + for (int i = 0; i < argc; i++) { + robj *str = (robj *)argv[i]; + int nexts; + sds *exts = sdssplitlen(str->ptr, strlen(str->ptr), "=", 1, &nexts); + if (nexts != 2) { + serverLog(LL_WARNING, "RDMA: Unsupported argument \"%s\"", (char *)str->ptr); + return VALKEYMODULE_ERR; + } + + if (!strcasecmp(exts[0], "bind")) { + listener->bindaddr[listener->bindaddr_count++] = zstrdup(exts[1]); + } else if (!strcasecmp(exts[0], "port")) { + listener->port = atoi(exts[1]); + } else if (!strcasecmp(exts[0], "rx-size")) { + valkey_rdma_rx_size = atoi(exts[1]); + } else if (!strcasecmp(exts[0], "comp-vector")) { + valkey_rdma_comp_vector = atoi(exts[1]); + } else { + serverLog(LL_WARNING, "RDMA: Unsupported argument \"%s\"", (char *)str->ptr); + return VALKEYMODULE_ERR; + } + + sdsfreesplitres(exts, nexts); + } + + rdmaBuildBind(ctx); + if (valkey_rdma_comp_vector == -1) valkey_rdma_comp_vector = abs((int)random()); + + return VALKEYMODULE_OK; +} + +int ValkeyModule_OnUnload(void *arg) { + UNUSED(arg); + serverLog(LL_NOTICE, 
"Connection type %s can not be unloaded", CONN_TYPE_RDMA); + return VALKEYMODULE_ERR; +} + +#endif /* BUILD_RDMA_MODULE */ + +#endif /* USE_RDMA && __linux__ */ diff --git a/tests/rdma/Makefile b/tests/rdma/Makefile new file mode 100644 index 0000000000..aabf98b0b2 --- /dev/null +++ b/tests/rdma/Makefile @@ -0,0 +1,16 @@ +BIN = rdma-test + +ifeq ($(RDMA_PKGCONFIG),0) + RDMA_LIBS=$(shell $(PKG_CONFIG) --libs librdmacm libibverbs) +else + RDMA_LIBS=-lrdmacm -libverbs +endif + +$(BIN): rdma-test.c + @$(CC) $^ -o $@ $(RDMA_LIBS) -lpthread -g + @echo "\nHint: please check the RDMA environment:" + @echo "\t~# rdma res show" + @echo "\n Then launch valkey-server with RDMA support, Run ./"$(BIN) "to test ..." + +clean: + rm -rf $(BIN) diff --git a/tests/rdma/rdma-test.c b/tests/rdma/rdma-test.c new file mode 100644 index 0000000000..be86a5e4d6 --- /dev/null +++ b/tests/rdma/rdma-test.c @@ -0,0 +1,1059 @@ +/* ========================================================================== + * rdma-test.c - a simple test client for Valkey Over RDMA (Linux only) + * -------------------------------------------------------------------------- + * Copyright (C) 2021-2024 zhenwei pi + * + * This work is licensed under BSD 3-Clause, License 1 of the COPYING file in + * the top-level directory. 
+ * ========================================================================== + */ + +#ifndef __linux__ /* currently RDMA is only supported on Linux */ + +#error "BUILD ERROR: RDMA is only supported on Linux" + +#else /* __linux__ */ + +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +typedef struct valkeyRdmaFeature { + /* defined as following Opcodes */ + uint16_t opcode; + /* select features */ + uint16_t select; + uint8_t rsvd[20]; + /* feature bits */ + uint64_t features; +} valkeyRdmaFeature; + +typedef struct valkeyRdmaKeepalive { + /* defined as following Opcodes */ + uint16_t opcode; + uint8_t rsvd[30]; +} valkeyRdmaKeepalive; + +typedef struct valkeyRdmaMemory { + /* defined as following Opcodes */ + uint16_t opcode; + uint8_t rsvd[14]; + /* address of a transfer buffer which is used to receive remote streaming data, + * aka 'RX buffer address'. The remote side should use this as 'TX buffer address' */ + uint64_t addr; + /* length of the 'RX buffer' */ + uint32_t length; + /* the RDMA remote key of 'RX buffer' */ + uint32_t key; +} valkeyRdmaMemory; + +typedef union valkeyRdmaCmd { + valkeyRdmaFeature feature; + valkeyRdmaKeepalive keepalive; + valkeyRdmaMemory memory; +} valkeyRdmaCmd; + +typedef enum valkeyRdmaOpcode { + GetServerFeature = 0, + SetClientFeature = 1, + Keepalive = 2, + RegisterXferMemory = 3, +} valkeyRdmaOpcode; + +#define MAX_THREADS 32 +#define UNUSED(x) (void)(x) +#define MIN(a, b) (a) < (b) ? 
a : b +#define VALKEY_RDMA_MAX_WQE 1024 +#define VALKEY_RDMA_DEFAULT_RX_LEN (1024*1024) +#define VALKEY_RDMA_INVALID_OPCODE 0xffff + +typedef struct RdmaContext { + struct rdma_cm_id *cm_id; + struct rdma_event_channel *cm_channel; + struct ibv_comp_channel *comp_channel; + struct ibv_cq *cq; + struct ibv_pd *pd; + bool connected; + + /* TX */ + char *tx_addr; + uint32_t tx_length; + uint32_t tx_offset; + uint32_t tx_key; + char *send_buf; + uint32_t send_length; + uint32_t send_ops; + struct ibv_mr *send_mr; + + /* RX */ + uint32_t rx_offset; + char *recv_buf; + unsigned int recv_length; + unsigned int recv_offset; + struct ibv_mr *recv_mr; + + /* CMD 0 ~ VALKEY_RDMA_MAX_WQE for recv buffer + * VALKEY_RDMA_MAX_WQE ~ 2 * VALKEY_RDMA_MAX_WQE -1 for send buffer */ + valkeyRdmaCmd *cmd_buf; + struct ibv_mr *cmd_mr; +} RdmaContext; + +static int valkeySetFdBlocking(int fd, int blocking) { + int flags; + + flags = fcntl(fd, F_GETFL); + if (flags == -1) { + return -1; + } + + if (blocking) + flags &= ~O_NONBLOCK; + else + flags |= O_NONBLOCK; + + return fcntl(fd, F_SETFL, flags); +} + +#define rdmaFatal(msg) \ + do { \ + fprintf(stderr, "%s:%d %s\n", __func__, __LINE__, msg); \ + assert(0); \ + } while (0) + +static inline long valkeyNowMs(void) { + struct timeval tv; + + if (gettimeofday(&tv, NULL) < 0) + return -1; + + return tv.tv_sec * 1000 + tv.tv_usec / 1000; +} + +static int rdmaPostRecv(RdmaContext *ctx, struct rdma_cm_id *cm_id, valkeyRdmaCmd *cmd) { + struct ibv_sge sge; + size_t length = sizeof(valkeyRdmaCmd); + struct ibv_recv_wr recv_wr, *bad_wr; + + + sge.addr = (uint64_t)cmd; + sge.length = length; + sge.lkey = ctx->cmd_mr->lkey; + + recv_wr.wr_id = (uint64_t)cmd; + recv_wr.sg_list = &sge; + recv_wr.num_sge = 1; + recv_wr.next = NULL; + + if (ibv_post_recv(cm_id->qp, &recv_wr, &bad_wr)) { + return -1; + } + + return 0; +} + +static void rdmaDestroyIoBuf(RdmaContext *ctx) { + if (ctx->recv_mr) { + ibv_dereg_mr(ctx->recv_mr); + ctx->recv_mr = NULL; + } + + 
free(ctx->recv_buf);
+    ctx->recv_buf = NULL;
+
+    if (ctx->send_mr) {
+        ibv_dereg_mr(ctx->send_mr);
+        ctx->send_mr = NULL;
+    }
+
+    free(ctx->send_buf);
+    ctx->send_buf = NULL;
+
+    if (ctx->cmd_mr) {
+        ibv_dereg_mr(ctx->cmd_mr);
+        ctx->cmd_mr = NULL;
+    }
+
+    free(ctx->cmd_buf);
+    ctx->cmd_buf = NULL;
+}
+
+/* Allocate and register the command buffer and the receive buffer.
+ * The first VALKEY_RDMA_MAX_WQE command slots are posted as receive buffers,
+ * the remaining VALKEY_RDMA_MAX_WQE slots are marked free for sending.
+ * Returns 0 on success; on failure reports through rdmaFatal(), releases
+ * everything via rdmaDestroyIoBuf() and returns -1. */
+static int rdmaSetupIoBuf(RdmaContext *ctx, struct rdma_cm_id *cm_id) {
+    int access = IBV_ACCESS_LOCAL_WRITE;
+    size_t length = sizeof(valkeyRdmaCmd) * VALKEY_RDMA_MAX_WQE * 2;
+    valkeyRdmaCmd *cmd;
+    int i;
+
+    /* setup CMD buf & MR */
+    ctx->cmd_buf = calloc(length, 1);
+    if (!ctx->cmd_buf) {
+        rdmaFatal("RDMA: alloc cmd buf failed");
+        goto destroy_iobuf;
+    }
+
+    ctx->cmd_mr = ibv_reg_mr(ctx->pd, ctx->cmd_buf, length, access);
+    if (!ctx->cmd_mr) {
+        rdmaFatal("RDMA: reg cmd mr failed");
+        goto destroy_iobuf;
+    }
+
+    for (i = 0; i < VALKEY_RDMA_MAX_WQE; i++) {
+        cmd = ctx->cmd_buf + i;
+
+        if (rdmaPostRecv(ctx, cm_id, cmd) == -1) {
+            rdmaFatal("RDMA: post recv failed");
+            goto destroy_iobuf;
+        }
+    }
+
+    for (i = VALKEY_RDMA_MAX_WQE; i < VALKEY_RDMA_MAX_WQE * 2; i++) {
+        cmd = ctx->cmd_buf + i;
+        cmd->keepalive.opcode = VALKEY_RDMA_INVALID_OPCODE;
+    }
+
+    /* setup recv buf & MR */
+    access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE;
+    length = VALKEY_RDMA_DEFAULT_RX_LEN;
+    ctx->recv_buf = calloc(length, 1);
+    if (!ctx->recv_buf) {
+        rdmaFatal("RDMA: alloc recv buf failed");
+        goto destroy_iobuf;
+    }
+
+    ctx->recv_length = length;
+    ctx->recv_mr = ibv_reg_mr(ctx->pd, ctx->recv_buf, length, access);
+    if (!ctx->recv_mr) {
+        rdmaFatal("RDMA: reg recv mr failed");
+        goto destroy_iobuf;
+    }
+
+    return 0;
+
+destroy_iobuf:
+    rdmaDestroyIoBuf(ctx);
+    return -1;
+}
+
+/* Make the send buffer exactly 'length' bytes long. Nothing happens when the
+ * size is unchanged; otherwise the old buffer and memory region are released
+ * and new ones are allocated and registered.
+ * Returns 0 on success, -1 on failure (the send buffer is then empty). */
+static int rdmaAdjustSendbuf(RdmaContext *ctx, unsigned int length) {
+    int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE;
+
+    if (length == ctx->send_length) {
+        return 0;
+    }
+
+    /* try to free old MR & buffer */
+    if (ctx->send_length) {
+        ibv_dereg_mr(ctx->send_mr);
+        ctx->send_mr = NULL;
+        free(ctx->send_buf);
+        ctx->send_buf = NULL;
+        ctx->send_length = 0;
+    }
+
+    /* create a new buffer & MR */
+    ctx->send_buf = calloc(length, 1);
+    if (!ctx->send_buf) {
+        rdmaFatal("RDMA: alloc send buf failed");
+        return -1;
+    }
+
+    ctx->send_length = length;
+    ctx->send_mr = ibv_reg_mr(ctx->pd, ctx->send_buf, length, access);
+    if (!ctx->send_mr) {
+        rdmaFatal("RDMA: reg send buf mr failed");
+        free(ctx->send_buf);
+        ctx->send_buf = NULL;
+        ctx->send_length = 0;
+        return -1;
+    }
+
+    return 0;
+}
+
+
+/* Copy 'cmd' into a free slot of the send half of the command buffer and post
+ * it as a signaled IBV_WR_SEND; the slot address is used as work request id.
+ * Returns 0 on success, -1 when no slot is free or the post fails. */
+static int rdmaSendCommand(RdmaContext *ctx, struct rdma_cm_id *cm_id, valkeyRdmaCmd *cmd) {
+    struct ibv_send_wr send_wr, *bad_wr;
+    struct ibv_sge sge;
+    valkeyRdmaCmd *_cmd;
+    int i;
+    int ret;
+
+    /* find an unused cmd buffer */
+    for (i = VALKEY_RDMA_MAX_WQE; i < 2 * VALKEY_RDMA_MAX_WQE; i++) {
+        _cmd = ctx->cmd_buf + i;
+        if (_cmd->keepalive.opcode == VALKEY_RDMA_INVALID_OPCODE) {
+            break;
+        }
+    }
+
+    /* Every slot is in flight. Do not rely on assert() alone: with NDEBUG the
+     * code below would reuse a slot that is still being sent. */
+    if (i >= 2 * VALKEY_RDMA_MAX_WQE) {
+        rdmaFatal("RDMA: no free cmd buffer");
+        return -1;
+    }
+
+    memcpy(_cmd, cmd, sizeof(valkeyRdmaCmd));
+    sge.addr = (uint64_t)_cmd;
+    sge.length = sizeof(valkeyRdmaCmd);
+    sge.lkey = ctx->cmd_mr->lkey;
+
+    send_wr.sg_list = &sge;
+    send_wr.num_sge = 1;
+    send_wr.wr_id = (uint64_t)_cmd;
+    send_wr.opcode = IBV_WR_SEND;
+    send_wr.send_flags = IBV_SEND_SIGNALED;
+    send_wr.next = NULL;
+    ret = ibv_post_send(cm_id->qp, &send_wr, &bad_wr);
+    if (ret) {
+        return -1;
+    }
+
+    return 0;
+}
+
+/* Advertise the local receive buffer (address, length and remote key, all in
+ * network byte order) to the peer with a RegisterXferMemory command, and
+ * reset the local receive offsets. Returns the result of rdmaSendCommand(). */
+static int connRdmaRegisterRx(RdmaContext *ctx, struct rdma_cm_id *cm_id) {
+    valkeyRdmaCmd cmd = { 0 };
+
+    cmd.memory.opcode = htons(RegisterXferMemory);
+    cmd.memory.addr = htobe64((uint64_t)ctx->recv_buf);
+    cmd.memory.length = htonl(ctx->recv_length);
+    cmd.memory.key = htonl(ctx->recv_mr->rkey);
+
+    ctx->rx_offset = 0;
+    ctx->recv_offset = 0;
+
+    return rdmaSendCommand(ctx, cm_id, &cmd);
+}
+
+static int connRdmaHandleRecv(RdmaContext *ctx, struct rdma_cm_id *cm_id, valkeyRdmaCmd *cmd, uint32_t byte_len) {
+    if (byte_len != sizeof(valkeyRdmaCmd)) {
+        rdmaFatal("RDMA: FATAL error, recv corrupted cmd");
+        return -1;
+    }
+
+    switch (ntohs(cmd->keepalive.opcode)) {
+    case RegisterXferMemory:
+        ctx->tx_addr = (char *)be64toh(cmd->memory.addr);
+        ctx->tx_length = ntohl(cmd->memory.length);
+        ctx->tx_key = ntohl(cmd->memory.key);
+        ctx->tx_offset = 0;
+
if (rdmaAdjustSendbuf(ctx, ctx->tx_length) == -1) {
+            /* without a registered send buffer nothing can be written to the
+             * peer any more, so report the failure instead of ignoring it */
+            return -1;
+        }
+        break;
+
+    case Keepalive:
+        break;
+
+    default:
+        rdmaFatal("RDMA: FATAL error, unknown cmd");
+        return -1;
+    }
+
+    return rdmaPostRecv(ctx, cm_id, cmd);
+}
+
+/* Account for 'byte_len' bytes the peer wrote into the receive buffer
+ * (the value arrives as immediate data), then re-post the receive slot 'cmd'.
+ * Returns the result of rdmaPostRecv(). */
+static int connRdmaHandleRecvImm(RdmaContext *ctx, struct rdma_cm_id *cm_id, valkeyRdmaCmd *cmd, uint32_t byte_len) {
+    /* NOTE(review): byte_len comes from the peer and is only validated by this
+     * assert, which disappears when built with NDEBUG. */
+    assert(byte_len + ctx->rx_offset <= ctx->recv_length);
+    ctx->rx_offset += byte_len;
+
+    return rdmaPostRecv(ctx, cm_id, cmd);
+}
+
+/* Send completion: return the outgoing command slot to the free pool. */
+static int connRdmaHandleSend(valkeyRdmaCmd *cmd) {
+    /* mark this cmd has already sent */
+    memset(cmd, 0x00, sizeof(*cmd));
+    cmd->keepalive.opcode = VALKEY_RDMA_INVALID_OPCODE;
+
+    return 0;
+}
+
+/* RDMA write completion: nothing to do, always returns 0. */
+static int connRdmaHandleWrite(RdmaContext *ctx, uint32_t byte_len) {
+    UNUSED(ctx);
+    UNUSED(byte_len);
+
+    return 0;
+}
+
+/* Fetch a pending completion-channel event (if any), re-arm the CQ, then
+ * poll and dispatch work completions one at a time until the CQ is empty.
+ * Returns 0 once the CQ is drained, -1 on any error. */
+static int connRdmaHandleCq(RdmaContext *ctx) {
+    struct rdma_cm_id *cm_id = ctx->cm_id;
+    struct ibv_cq *ev_cq = NULL;
+    void *ev_ctx = NULL;
+    struct ibv_wc wc = {0};
+    valkeyRdmaCmd *cmd;
+    int ret;
+
+    /* EAGAIN only means "no event pending" here; polling still happens below */
+    if (ibv_get_cq_event(ctx->comp_channel, &ev_cq, &ev_ctx) < 0) {
+        if (errno != EAGAIN) {
+            rdmaFatal("RDMA: get cq event failed");
+            return -1;
+        }
+    } else if (ibv_req_notify_cq(ev_cq, 0)) {
+        rdmaFatal("RDMA: notify cq failed");
+        return -1;
+    }
+
+pollcq:
+    ret = ibv_poll_cq(ctx->cq, 1, &wc);
+    if (ret < 0) {
+        rdmaFatal("RDMA: poll cq failed");
+        return -1;
+    } else if (ret == 0) {
+        return 0;
+    }
+
+    /* NOTE(review): this acknowledges one CQ event per polled work completion,
+     * not one per event returned by ibv_get_cq_event() - confirm against the
+     * ibv_ack_cq_events() contract. */
+    ibv_ack_cq_events(ctx->cq, 1);
+
+    if (wc.status != IBV_WC_SUCCESS) {
+        rdmaFatal("RDMA: send/recv failed");
+        return -1;
+    }
+
+    switch (wc.opcode) {
+    case IBV_WC_RECV:
+        cmd = (valkeyRdmaCmd *)wc.wr_id;
+        if (connRdmaHandleRecv(ctx, cm_id, cmd, wc.byte_len) == -1) {
+            return -1;
+        }
+
+        break;
+
+    case IBV_WC_RECV_RDMA_WITH_IMM:
+        cmd = (valkeyRdmaCmd *)wc.wr_id;
+        if (connRdmaHandleRecvImm(ctx, cm_id, cmd, ntohl(wc.imm_data)) == -1) {
+            return -1;
+        }
+
+        break;
+    case IBV_WC_RDMA_WRITE:
+        if (connRdmaHandleWrite(ctx, wc.byte_len) == -1) {
+            return -1;
+        }
+
+        break;
+    case IBV_WC_SEND:
+        cmd
= (valkeyRdmaCmd *)wc.wr_id;
+        if (connRdmaHandleSend(cmd) == -1) {
+            return -1;
+        }
+
+        break;
+    default:
+        rdmaFatal("RDMA: unexpected opcode");
+        return -1;
+    }
+
+    goto pollcq;
+
+    return 0;
+}
+
+/* Copy up to 'data_len' bytes of received data into 'buf'.
+ *
+ * Data already sitting in the receive buffer is returned immediately.
+ * Otherwise the CQ is polled (sleeping on the completion channel between
+ * rounds) until data shows up or about one second has passed.
+ * When the receive buffer has been fully consumed it is re-registered with
+ * the peer.
+ *
+ * Returns the number of bytes copied, or -1 on error/timeout. */
+static ssize_t valkeyRdmaRead(RdmaContext *ctx, char *buf, size_t data_len) {
+    struct rdma_cm_id *cm_id = ctx->cm_id;
+    struct pollfd pfd;
+    long timed = 1000;
+    long start = valkeyNowMs();
+    uint32_t toread, remained;
+
+copy:
+    if (ctx->recv_offset < ctx->rx_offset) {
+        remained = ctx->rx_offset - ctx->recv_offset;
+        toread = MIN(remained, data_len);
+
+        memcpy(buf, ctx->recv_buf + ctx->recv_offset, toread);
+        ctx->recv_offset += toread;
+
+        if (ctx->recv_offset == ctx->recv_length) {
+            connRdmaRegisterRx(ctx, cm_id);
+        }
+
+        return toread;
+    }
+
+pollcq:
+    /* try to poll a CQ firstly */
+    if (connRdmaHandleCq(ctx) == -1) {
+        return -1;
+    }
+
+    if (ctx->recv_offset < ctx->rx_offset) {
+        goto copy;
+    }
+
+    pfd.fd = ctx->comp_channel->fd;
+    pfd.events = POLLIN;
+    pfd.revents = 0;
+    if (poll(&pfd, 1, 1000) < 0) {
+        return -1;
+    }
+
+    if ((valkeyNowMs() - start) < timed) {
+        goto pollcq;
+    }
+
+    rdmaFatal("RDMA: read timeout");
+    return -1;
+}
+
+/* Read exactly 'data_len' bytes into 'buf' by calling valkeyRdmaRead()
+ * repeatedly.
+ *
+ * Returns 'data_len' on success, or -1 as soon as one read fails. The failure
+ * must not be added to the running byte count: that counter is unsigned, so
+ * -1 would wrap around and make the call look successful. */
+static ssize_t valkeyRdmaReadFull(RdmaContext *ctx, char *buf, size_t data_len) {
+    size_t inbytes = 0;
+    ssize_t ret;
+
+    while (inbytes < data_len) {
+        ret = valkeyRdmaRead(ctx, buf + inbytes, data_len - inbytes);
+        if (ret <= 0) {
+            return -1;
+        }
+
+        inbytes += (size_t)ret;
+    }
+
+    return data_len;
+}
+
+static size_t connRdmaSend(RdmaContext *ctx, struct rdma_cm_id *cm_id, const void *data, size_t data_len) {
+    struct ibv_send_wr send_wr, *bad_wr;
+    struct ibv_sge sge;
+    uint32_t off = ctx->tx_offset;
+    char *addr = ctx->send_buf + off;
+    char *remote_addr = ctx->tx_addr + off;
+    int ret;
+
+    assert(data_len <= ctx->tx_length);
+    memcpy(addr, data, data_len);
+
+    sge.addr = (uint64_t)addr;
+    sge.lkey = ctx->send_mr->lkey;
+    sge.length = data_len;
+
+    send_wr.sg_list = &sge;
+    send_wr.num_sge = 1;
+    send_wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+    send_wr.send_flags = (++ctx->send_ops %
VALKEY_RDMA_MAX_WQE) ? 0 : IBV_SEND_SIGNALED;
+    send_wr.imm_data = htonl(data_len);
+    send_wr.wr.rdma.remote_addr = (uint64_t)remote_addr;
+    send_wr.wr.rdma.rkey = ctx->tx_key;
+    send_wr.next = NULL;
+    ret = ibv_post_send(cm_id->qp, &send_wr, &bad_wr);
+    if (ret) {
+        return -1;
+    }
+
+    ctx->tx_offset += data_len;
+
+    return data_len;
+}
+
+/* Write 'data_len' bytes from 'buf' to the peer.
+ *
+ * Data is pushed in chunks that fit the remaining space of the peer's
+ * advertised buffer (tx_length - tx_offset). When that buffer is full the
+ * function waits for the peer to register a new one.
+ *
+ * Returns 'data_len' when everything was written, or -1 on error or when the
+ * whole operation takes longer than about one second. */
+static ssize_t valkeyRdmaWrite(RdmaContext *ctx, char *buf, size_t data_len) {
+    struct rdma_cm_id *cm_id = ctx->cm_id;
+    struct pollfd pfd;
+    long timed = 1000;
+    long start = valkeyNowMs();
+    uint32_t towrite, wrote = 0;
+    size_t ret;
+
+    /* try to pollcq to */
+    goto pollcq;
+
+waitcq:
+    pfd.fd = ctx->comp_channel->fd;
+    pfd.events = POLLIN;
+    pfd.revents = 0;
+    if (poll(&pfd, 1, 1) < 0) {
+        return -1;
+    }
+
+pollcq:
+    if (connRdmaHandleCq(ctx) == -1) {
+        return -1;
+    }
+
+    assert(ctx->tx_offset <= ctx->tx_length);
+    if (ctx->tx_offset == ctx->tx_length) {
+        /* wait a new TX buffer, but honour the timeout here as well: a peer
+         * that never registers a new buffer must not hang the caller */
+        if ((valkeyNowMs() - start) < timed) {
+            goto waitcq;
+        }
+
+        rdmaFatal("RDMA: write timeout");
+
+        return -1;
+    }
+
+    towrite = MIN(ctx->tx_length - ctx->tx_offset, data_len - wrote);
+    ret = connRdmaSend(ctx, cm_id, buf + wrote, towrite);
+    if (ret == (size_t)-1) {
+        return -1;
+    }
+
+    wrote += ret;
+    if (wrote == data_len) {
+        return data_len;
+    }
+
+    if ((valkeyNowMs() - start) < timed) {
+        goto waitcq;
+    }
+
+    rdmaFatal("RDMA: write timeout");
+
+    return -1;
+}
+
+/* Disconnect and release every verbs/CM resource owned by 'ctx'.
+ * The context structure itself is not freed here. */
+static void valkeyRdmaClose(RdmaContext *ctx) {
+    struct rdma_cm_id *cm_id = ctx->cm_id;
+
+    connRdmaHandleCq(ctx);
+    rdma_disconnect(cm_id);
+    /* the QP has to go first: ibv_destroy_cq() fails while a queue pair is
+     * still associated with the CQ */
+    ibv_destroy_qp(cm_id->qp);
+    ibv_destroy_cq(ctx->cq);
+    rdmaDestroyIoBuf(ctx);
+    ibv_destroy_comp_channel(ctx->comp_channel);
+    ibv_dealloc_pd(ctx->pd);
+    rdma_destroy_id(cm_id);
+
+    rdma_destroy_event_channel(ctx->cm_channel);
+}
+
+/* Free a context structure; NULL is accepted and ignored. */
+static void valkeyRdmaFree(void *privctx) {
+    if (!privctx)
+        return;
+
+    free(privctx);
+}
+
+static int valkeyRdmaConnect(RdmaContext *ctx, struct rdma_cm_id *cm_id) {
+    struct ibv_comp_channel *comp_channel = NULL;
+    struct ibv_cq *cq = NULL;
+    struct ibv_pd *pd = NULL;
+    struct ibv_qp_init_attr
init_attr = {0};
+    struct rdma_conn_param conn_param = {0};
+
+    pd = ibv_alloc_pd(cm_id->verbs);
+    if (!pd) {
+        rdmaFatal("RDMA: alloc pd failed");
+        goto error;
+    }
+
+    comp_channel = ibv_create_comp_channel(cm_id->verbs);
+    if (!comp_channel) {
+        rdmaFatal("RDMA: create comp channel failed");
+        goto error;
+    }
+
+    /* the completion channel is polled, so its fd must not block */
+    if (valkeySetFdBlocking(comp_channel->fd, 0) != 0) {
+        rdmaFatal("RDMA: set recv comp channel fd non-block failed");
+        goto error;
+    }
+
+    /* a single CQ serves both directions, hence twice the WQE count */
+    cq = ibv_create_cq(cm_id->verbs, VALKEY_RDMA_MAX_WQE * 2, ctx, comp_channel, 0);
+    if (!cq) {
+        rdmaFatal("RDMA: create send cq failed");
+        goto error;
+    }
+
+    if (ibv_req_notify_cq(cq, 0)) {
+        rdmaFatal("RDMA: notify send cq failed");
+        goto error;
+    }
+
+    /* create qp with attr */
+    init_attr.cap.max_send_wr = VALKEY_RDMA_MAX_WQE;
+    init_attr.cap.max_recv_wr = VALKEY_RDMA_MAX_WQE;
+    init_attr.cap.max_send_sge = 1;
+    init_attr.cap.max_recv_sge = 1;
+    init_attr.qp_type = IBV_QPT_RC;
+    init_attr.send_cq = cq;
+    init_attr.recv_cq = cq;
+    if (rdma_create_qp(cm_id, pd, &init_attr)) {
+        rdmaFatal("RDMA: create qp failed");
+        goto error;
+    }
+
+    /* rdmaSetupIoBuf() reads ctx->pd, so publish the handles before it runs */
+    ctx->cm_id = cm_id;
+    ctx->comp_channel = comp_channel;
+    ctx->cq = cq;
+    ctx->pd = pd;
+
+    if (rdmaSetupIoBuf(ctx, cm_id) != 0)
+        goto free_qp;
+
+    /* rdma connect with param */
+    conn_param.responder_resources = 1;
+    conn_param.initiator_depth = 1;
+    conn_param.retry_count = 7;
+    conn_param.rnr_retry_count = 7;
+    if (rdma_connect(cm_id, &conn_param)) {
+        rdmaFatal("RDMA: connect failed");
+        goto destroy_iobuf;
+    }
+
+    return 0;
+
+    /* unwind in reverse order of creation; the QP goes before its CQ */
+destroy_iobuf:
+    rdmaDestroyIoBuf(ctx);
+free_qp:
+    ibv_destroy_qp(cm_id->qp);
+error:
+    if (cq)
+        ibv_destroy_cq(cq);
+    if (pd)
+        ibv_dealloc_pd(pd);
+    if (comp_channel)
+        ibv_destroy_comp_channel(comp_channel);
+
+    return -1;
+}
+
+/* Connection established: flag the context as connected and advertise the
+ * local receive buffer to the peer. Returns the result of
+ * connRdmaRegisterRx(). */
+static int valkeyRdmaEstablished(RdmaContext *ctx, struct rdma_cm_id *cm_id) {
+
+    /* it's time to tell the caller we have already connected */
+    ctx->connected = true;
+
+    return connRdmaRegisterRx(ctx, cm_id);
+}
+
+static int
valkeyRdmaCM(RdmaContext *ctx, int timeout) {
+    struct rdma_cm_event *event;
+    char errorstr[128];
+    int ret = -1;
+
+    /* Drain every CM event currently queued on the channel. 'ret' carries the
+     * result of the last event handled and stays -1 when none was pending. */
+    while (rdma_get_cm_event(ctx->cm_channel, &event) == 0) {
+        /* printf("GET RDMA CM EVENT: %s\n", rdma_event_str(event->event)); */
+        switch (event->event) {
+        case RDMA_CM_EVENT_ADDR_RESOLVED:
+            if (timeout < 0 || timeout > 100)
+                timeout = 100; /* at most 100ms to resolve route */
+            ret = rdma_resolve_route(event->id, timeout);
+            if (ret) {
+                rdmaFatal("RDMA: route resolve failed");
+            }
+            break;
+        case RDMA_CM_EVENT_ROUTE_RESOLVED:
+            /* route known: create PD/CQ/QP and start connecting */
+            ret = valkeyRdmaConnect(ctx, event->id);
+            break;
+        case RDMA_CM_EVENT_ESTABLISHED:
+            ret = valkeyRdmaEstablished(ctx, event->id);
+            break;
+        case RDMA_CM_EVENT_TIMEWAIT_EXIT:
+            ret = -1;
+            rdmaFatal("RDMA: connect timeout");
+            break;
+        /* every remaining event, listed or not, is treated as a failure */
+        case RDMA_CM_EVENT_ADDR_ERROR:
+        case RDMA_CM_EVENT_ROUTE_ERROR:
+        case RDMA_CM_EVENT_CONNECT_ERROR:
+        case RDMA_CM_EVENT_UNREACHABLE:
+        case RDMA_CM_EVENT_REJECTED:
+        case RDMA_CM_EVENT_DISCONNECTED:
+        case RDMA_CM_EVENT_ADDR_CHANGE:
+        default:
+            snprintf(errorstr, sizeof(errorstr), "RDMA: connect failed - %s", rdma_event_str(event->event));
+            rdmaFatal(errorstr);
+            ret = -1;
+            break;
+        }
+
+        rdma_ack_cm_event(event);
+    }
+
+    return ret;
+}
+
+/* Wait up to 'timeout' milliseconds for the connection of 'ctx' to be
+ * established, handling CM events as they arrive.
+ * Returns 0 once ctx->connected is set, -1 on error or timeout. */
+static int valkeyRdmaWaitConn(RdmaContext *ctx, long timeout) {
+    int timed;
+    struct pollfd pfd;
+    long now = valkeyNowMs();
+    long start = now;
+
+    while (now - start < timeout) {
+        /* remaining time budget for this round */
+        timed = (int)(timeout - (now - start));
+
+        pfd.fd = ctx->cm_channel->fd;
+        pfd.events = POLLIN;
+        pfd.revents = 0;
+        if (poll(&pfd, 1, timed) < 0) {
+            return -1;
+        }
+
+        /* valkeyRdmaCM() also returns -1 when no event was queued, so a
+         * poll() timeout ends the wait here as well */
+        if (valkeyRdmaCM(ctx, timed) == -1) {
+            return -1;
+        }
+
+        if (ctx->connected) {
+            return 0;
+        }
+
+        now = valkeyNowMs();
+    }
+
+    return -1;
+}
+
+/* Resolve 'addr':'port' and try to connect to it over RDMA, spending at most
+ * 'timeout' milliseconds overall. */
+static RdmaContext *valkeyContextConnectRdma(const char *addr, int port, int timeout) {
+    int ret;
+    char _port[6]; /* strlen("65535"); */
+    struct addrinfo hints, *servinfo = NULL, *p;
+    RdmaContext *ctx = NULL;
+    struct sockaddr_storage saddr;
+    long
start = valkeyNowMs(), timed;
+
+    snprintf(_port, 6, "%d", port);
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_family = AF_INET;
+    hints.ai_socktype = SOCK_STREAM;
+
+    /* try IPv4 first and fall back to IPv6 */
+    if ((ret = getaddrinfo(addr, _port, &hints, &servinfo)) != 0) {
+        hints.ai_family = AF_INET6;
+        if ((ret = getaddrinfo(addr, _port, &hints, &servinfo)) != 0) {
+            rdmaFatal(gai_strerror(ret));
+            return NULL;
+        }
+    }
+
+    ctx = calloc(sizeof(RdmaContext), 1);
+    if (!ctx) {
+        rdmaFatal("Out of memory");
+        goto free_rdma;
+    }
+
+    ctx->cm_channel = rdma_create_event_channel();
+    if (!ctx->cm_channel) {
+        rdmaFatal("RDMA: create event channel failed");
+        goto free_rdma;
+    }
+
+    if (rdma_create_id(ctx->cm_channel, &ctx->cm_id, (void *)ctx, RDMA_PS_TCP)) {
+        rdmaFatal("RDMA: create id failed");
+        goto free_rdma;
+    }
+
+    if ((valkeySetFdBlocking(ctx->cm_channel->fd, 0) != 0)) {
+        rdmaFatal("RDMA: set cm channel fd non-block failed");
+        goto free_rdma;
+    }
+
+    for (p = servinfo; p != NULL; p = p->ai_next) {
+        if (p->ai_family == PF_INET) {
+            memcpy(&saddr, p->ai_addr, sizeof(struct sockaddr_in));
+            ((struct sockaddr_in *)&saddr)->sin_port = htons(port);
+        } else if (p->ai_family == PF_INET6) {
+            memcpy(&saddr, p->ai_addr, sizeof(struct sockaddr_in6));
+            ((struct sockaddr_in6 *)&saddr)->sin6_port = htons(port);
+        } else {
+            rdmaFatal("RDMA: unsupported family");
+            goto free_rdma;
+        }
+
+        /* resolve addr as most 100ms */
+        if (rdma_resolve_addr(ctx->cm_id, NULL, (struct sockaddr *)&saddr, 100)) {
+            continue;
+        }
+
+        /* whatever is left of the caller's time budget */
+        timed = timeout - (valkeyNowMs() - start);
+        if ((valkeyRdmaWaitConn(ctx, timed) == 0) && ctx->connected) {
+            ret = 0;
+            goto end;
+        }
+    }
+
+    if ((!ctx->connected) && (p == NULL)) {
+        rdmaFatal("RDMA: resolve failed");
+    }
+
+    /* Failure path. 'ctx' is NULL when its allocation failed, so it must be
+     * checked before it is dereferenced, and it is reset after free() so the
+     * caller receives NULL instead of a dangling pointer. */
+free_rdma:
+    if (ctx) {
+        if (ctx->cm_id) {
+            rdma_destroy_id(ctx->cm_id);
+        }
+        if (ctx->cm_channel) {
+            rdma_destroy_event_channel(ctx->cm_channel);
+        }
+
+        free(ctx);
+        ctx = NULL;
+    }
+
+end:
+    if(servinfo) {
+        freeaddrinfo(servinfo);
+    }
+
+    return ctx;
+}
+
+static int port = 6379;
+static char *host = NULL;
+static int minkeys = 128;
+static int maxkeys = 8192;
+static int keysize = 1024 + 1; /* for '\0' terminator */
+
+/* One generated key together with its heap-allocated random value. */
+struct test_kv_pair {
+    char key[32]; /* "THREAD01-000001" */
+    char *value;
+};
+
+/* Body of one test client: connect over RDMA, then run PING, SET, BGSAVE and
+ * GET rounds against the server, asserting on every reply. */
+static void *test_routine(void *arg) {
+    pid_t tid = gettid();
+    RdmaContext *ctx;
+    struct test_kv_pair *kv_pairs = NULL, *kv_pair;
+    int keys;
+
+    ctx = valkeyContextConnectRdma(host, port, 1000);
+    if (!ctx) {
+        rdmaFatal("RDMA connect failed");
+    }
+
+    int bufsize = keysize + 128;
+    char *inbuf = malloc(bufsize);
+    char *outbuf = malloc(bufsize);
+    int inbytes, outbytes;
+
+    /* # round 1, test PING */
+    char *pingcmd = "*1\r\n$4\r\nPING\r\n";
+    char *pingresp = "+PONG\r\n";
+
+    valkeyRdmaWrite(ctx, pingcmd, strlen(pingcmd));
+    inbytes = valkeyRdmaReadFull(ctx, inbuf, strlen(pingresp));
+    assert(!strncmp(pingresp, inbuf, inbytes));
+    printf("Valkey Over RDMA test thread[%d] PING/PONG [OK]\n", tid);
+
+    /* prepare random KV for SET/GET.
+     * minkeys == maxkeys is accepted by the option checks, so guard the
+     * modulo: a zero divisor would raise SIGFPE. */
+    keys = minkeys;
+    if (maxkeys > minkeys) {
+        keys += random() % (maxkeys - minkeys);
+    }
+    kv_pairs = calloc(sizeof(struct test_kv_pair), keys);
+
+    for (int i = 0; i < keys; i++) {
+        kv_pair = &kv_pairs[i];
+        snprintf(kv_pair->key, sizeof(kv_pair->key) - 1, "THREAD%02d-%06d", tid, i);
+        kv_pair->value = calloc(keysize, 1);
+        for (int k = 0; k < keysize - 1; k++) {
+            kv_pair->value[k] = 'A' + random() % 26; /* generate upper case string */
+        }
+    }
+    printf("Valkey Over RDMA test thread[%d] prepare %d KVs [OK]\n", tid, keys);
+
+    /* # round 2, test SET */
+    char *okresp = "+OK\r\n";
+
+    for (int i = 0; i < keys; i++) {
+        kv_pair = &kv_pairs[i];
+        /* build SET command; strlen() yields size_t, hence %zu */
+        outbytes = sprintf(outbuf, "*3\r\n$3\r\nSET\r\n$%zu\r\n%s\r\n$%zu\r\n%s\r\n",
+                           strlen(kv_pair->key), kv_pair->key,
+                           strlen(kv_pair->value), kv_pair->value);
+        valkeyRdmaWrite(ctx, outbuf, outbytes);
+        inbytes = valkeyRdmaReadFull(ctx, inbuf, strlen(okresp));
+        assert(!strncmp("+OK\r\n", inbuf, inbytes));
+    }
+    printf("Valkey Over RDMA test thread[%d] SET %d KVs [OK]\n", tid,
keys);
+
+    /* # round 3, test BGSAVE, to avoid "-ERR Background save already", run BGSAVE only once */
+    char *bgsavecmd = "*1\r\n$6\r\nBGSAVE\r\n";
+    char *bgsaveresp = "+Background saving started\r\n";
+    static int bgsaved;
+
+    /* only the first thread to bump the shared counter issues BGSAVE */
+    if (!__atomic_fetch_add(&bgsaved, 1, __ATOMIC_SEQ_CST)) {
+        valkeyRdmaWrite(ctx, bgsavecmd, strlen(bgsavecmd));
+        inbytes = valkeyRdmaReadFull(ctx, inbuf, strlen(bgsaveresp));
+        assert(!strncmp(bgsaveresp, inbuf, inbytes));
+        printf("Valkey Over RDMA test thread[%d] BGSAVE [OK]\n", tid);
+    }
+
+    /* # round 4, test GET. also verify all the value already set */
+    char *getrespprex = "$1024\r\n";
+    int getrespprexlen = strlen(getrespprex);
+
+    for (int i = 0; i < keys; i++) {
+        kv_pair = &kv_pairs[i];
+        /* build GET command; strlen() yields size_t, hence %zu */
+        outbytes = sprintf(outbuf, "*2\r\n$3\r\nGET\r\n$%zu\r\n%s\r\n",
+                           strlen(kv_pair->key), kv_pair->key);
+        valkeyRdmaWrite(ctx, outbuf, outbytes);
+        /* reply is: length prefix, the value itself, trailing "\r\n" */
+        inbytes = valkeyRdmaReadFull(ctx, inbuf, getrespprexlen + strlen(kv_pair->value) + 2);
+        assert(!strncmp(getrespprex, inbuf, getrespprexlen));
+        assert(!strncmp(kv_pair->value, inbuf + getrespprexlen, strlen(kv_pair->value)));
+    }
+    printf("Valkey Over RDMA test thread[%d] GET %d KVs [OK]\n", tid, keys);
+
+    return NULL;
+}
+
+/* Print the command line help of the test program 'proc'. */
+void usage(char *proc) {
+    printf("%s usage:\n", proc);
+    printf("\t--help/-H\n");
+    printf("\t--host/-h HOSTADDR\n");
+    printf("\t--port/-p PORT\n");
+    printf("\t--maxkeys/-M MAXKEYS\n");
+    printf("\t--minkeys/-m MINKEYS\n");
+    printf("\t--thread/-t THREADS\n");
+}
+
+int main(int argc, char *argv[])
+{
+    int c, args;
+    int nr_threads = 0;
+    pthread_t threads[MAX_THREADS];
+
+    static struct option long_opts[] = {
+        { "help", no_argument, NULL, 'H' },
+        { "host", required_argument, NULL, 'h' },
+        { "port", required_argument, NULL, 'p' },
+        { "maxkeys", required_argument, NULL, 'M' },
+        { "minkeys", required_argument, NULL, 'm' },
+        { "thread", required_argument, NULL, 't' },
+        { NULL, 0, NULL, 0 }, /* getopt_long() needs an all-zero terminator */
+    };
+    static char *short_opts = "Hh:p:t:M:m:";
+
+    while (1) {
+        c =
getopt_long(argc, argv, short_opts, long_opts, &args);
+        if (c == -1) {
+            break;
+        }
+        switch (c) {
+        case 'h':
+            host = optarg;
+            break;
+
+        case 'p':
+            port = atoi(optarg);
+            if (port <= 0 || port > 65535) {
+                rdmaFatal("invalid port");
+            }
+            break;
+
+        case 't':
+            nr_threads = atoi(optarg);
+            if (nr_threads < 0 || nr_threads > MAX_THREADS) {
+                rdmaFatal("--thread/-t is expected as [0, 32]");
+            }
+            break;
+
+        case 'M':
+            maxkeys = atoi(optarg);
+            break;
+
+        case 'm':
+            minkeys = atoi(optarg);
+            break;
+
+        case 'H':
+            usage(argv[0]);
+            exit(0);
+
+        default:
+            usage(argv[0]);
+            exit(-1); /* this is not considered as success, to avoid auto-test workaround */
+        }
+    }
+
+    if (!host) {
+        rdmaFatal("missing --host/-h");
+    }
+
+    if (minkeys > maxkeys) {
+        rdmaFatal("minkeys should less than maxkeys");
+    }
+
+    /* To make the test randomly */
+    srandom(time(NULL) ^ getpid());
+
+    /* main thread mode */
+    if (!nr_threads) {
+        printf("Test a single client in main thread ...\n");
+        test_routine(NULL);
+
+        return 0;
+    }
+
+    /* multi threads mode */
+    for (int i = 0; i < nr_threads; i++) {
+        /* keep the call outside assert(): with NDEBUG the whole assert
+         * expression is dropped and no thread would ever be created */
+        int err = pthread_create(&threads[i], NULL, test_routine, NULL);
+        assert(!err);
+        (void)err;
+    }
+
+    for (int i = 0; i < nr_threads; i++) {
+        pthread_join(threads[i], NULL);
+    }
+
+    printf("Valkey Over RDMA test [OK]\n");
+
+    return 0;
+}
+
+#endif /* __linux__ */
diff --git a/tests/rdma/rdma_env.py b/tests/rdma/rdma_env.py
new file mode 100755
index 0000000000..f85d7b1e89
--- /dev/null
+++ b/tests/rdma/rdma_env.py
@@ -0,0 +1,145 @@
+#!/usr/bin/python3
+"""
+==========================================================================
+rdma_env.py - script to setup/cleanup soft RDMA devices.
+              note that this script needs root privilege.
+--------------------------------------------------------------------------
+Copyright (C) 2024 zhenwei pi
+
+This work is licensed under BSD 3-Clause, License 1 of the COPYING file in
+the top-level directory.
+==========================================================================
+"""
+import os
+import subprocess
+import netifaces
+import time
+import argparse
+
+
+def prepare_ib():
+    """Load the kernel modules the RDMA stack needs; exit(1) on failure."""
+    cmd = "modprobe rdma_cm && modprobe udp_tunnel && modprobe ip6_udp_tunnel && modprobe ib_uverbs"
+    # modprobe reports errors on stderr, so fold it into the captured output
+    p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    # communicate() drains the pipe while waiting; a bare wait() on a piped
+    # child can block forever once the pipe buffer is full
+    outs, _ = p.communicate()
+    if p.returncode:
+        print("Valkey Over RDMA probe modules of IB [FAILED]")
+        # os._exit() does not flush stdio buffers, so flush explicitly
+        print("---------------\n" + outs.decode() + "---------------\n", flush=True)
+        os._exit(1)
+
+    print("Valkey Over RDMA probe modules of IB [OK]")
+
+
+def prepare_rxe(interface):
+    # is there any builtin rdma_rxe.ko?
+    p = subprocess.Popen("modprobe rdma_rxe 2> /dev/null", shell=True, stdout=subprocess.PIPE)
+    if p.wait():
+        valkeydir = os.path.dirname(os.path.abspath(__file__)) + "/../.."
+        rxedir = valkeydir + "/tests/rdma/rxe"
+        rxekmod = rxedir + "/rdma_rxe.ko"
+        print(rxedir)
+        print(rxekmod)
+        if not os.path.exists(rxekmod):
+            print("Neither kernel builtin nor out-of-tree rdma_rxe.ko found.
Abort")
+            print("Please run the following commands to build out-of-tree RXE on Linux-6.5, then retry:")
+            print("\t~# mkdir -p " + rxedir)
+            print("\t~# git clone https://github.com/pizhenwei/rxe.git " + rxedir)
+            print("\t~# cd " + rxedir)
+            print("\t~# make")
+            os._exit(1);
+
+        cmd = "insmod " + rxekmod
+        p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
+        if p.wait():
+            os._exit(1);
+
+    print("Valkey Over RDMA install RXE [OK]")
+
+    softrdma = "rxe_" + interface
+    cmd = "rdma link add " + softrdma + " type rxe netdev " + interface
+    p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
+    if p.wait():
+        outs, _ = p.communicate()
+        print("Valkey Over RDMA install RXE [FAILED]")
+        print("---------------\n" + outs.decode() + "---------------\n")
+        os._exit(1);
+
+    print("Valkey Over RDMA add RXE device <%s> [OK]" % softrdma)
+
+
+# find any IPv4 available networking interface
+def find_iface():
+    """Return the first non-loopback interface with an IPv4 address, or None."""
+    interfaces = netifaces.interfaces()
+    for interface in interfaces:
+        if interface == "lo":
+            continue
+
+        addrs = netifaces.ifaddresses(interface)
+        if netifaces.AF_INET not in addrs:
+            continue
+
+        return interface
+
+    return None
+
+
+def setup_rdma(driver, interface):
+    """Load the IB modules and create a soft RDMA device on `interface`.
+
+    When `interface` is None one is picked by find_iface(). Only the "rxe"
+    driver is supported. The process exits with status 1 on any failure.
+    """
+    if interface is None:
+        interface = find_iface()
+        if interface is None:
+            # without this check prepare_rxe() would fail with a TypeError
+            # while building the device name
+            print("No IPv4 available network interface found. Abort", flush=True)
+            os._exit(1)
+
+    prepare_ib()
+    if driver == "rxe":
+        prepare_rxe(interface)
+    else:
+        # os._exit() does not flush stdio buffers, so flush explicitly
+        print("rxe is currently supported only", flush=True)
+        os._exit(1)
+
+
+# iterate /sys/class/infiniband, find any all virtual RDMA device, and remove them
+def cleanup_rdma():
+    # Ex, /sys/class/infiniband/mlx5_0
+    # Ex, /sys/class/infiniband/rxe_eth0
+    # Ex, /sys/class/infiniband/siw_eth0
+    ibclass = "/sys/class/infiniband/"
+    try:
+        for dev in os.listdir(ibclass):
+            # Ex, /sys/class/infiniband/rxe_eth0/ports/1/gid_attrs/ndevs/0
+            origpath = os.readlink(ibclass + dev)
+            if "virtual" in origpath:
+                subprocess.Popen("rdma link del " + dev, shell=True).wait()
+                print("Remove virtual RDMA device : " + dev + " [OK]")
+    except os.error:
+        return None
+
+    # try to remove RXE driver from kernel,
ignore error
+    subprocess.Popen("rmmod rdma_rxe 2> /dev/null", shell=True).wait()
+
+    # try to remove SIW driver from kernel, ignore error
+    subprocess.Popen("rmmod rdma_siw 2> /dev/null", shell=True).wait()
+
+    return None
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description = "Script to setup/cleanup soft RDMA devices, note that root privilege is required",
+        formatter_class=argparse.RawDescriptionHelpFormatter)
+    parser.add_argument("-o", "--operation", type=str,
+                        help="[setup|cleanup] setup or cleanup soft RDMA environment")
+    parser.add_argument("-d", "--driver", type=str, default="rxe",
+                        help="[rxe|siw] specify soft RDMA driver, rxe by default")
+    parser.add_argument("-i", "--interface", type=str,
+                        help="[IFACE] network interface, auto-select any available interface by default")
+    args = parser.parse_args()
+
+    # test UID: a non-root user cannot load modules or create RDMA links,
+    # so show some hints and exit.
+    if os.geteuid():
+        print("You are not root privileged. Abort.")
+        print("Or you may setup RXE manually in root privileged by commands:")
+        print("\t~# modprobe rdma_rxe")
+        print("\t~# rdma link add rxe0 type rxe netdev [IFACE]")
+        os._exit(1);
+
+    # any other (or missing) operation falls through and exits with status 0
+    if args.operation == "cleanup":
+        cleanup_rdma()
+    elif args.operation == "setup":
+        setup_rdma(args.driver, args.interface)
+
+    os._exit(0);
diff --git a/tests/rdma/run.py b/tests/rdma/run.py
new file mode 100755
index 0000000000..61cae6f41f
--- /dev/null
+++ b/tests/rdma/run.py
@@ -0,0 +1,146 @@
+#!/usr/bin/python3
+"""
+==========================================================================
+run.py - script for test client for Valkey Over RDMA (Linux only)
+--------------------------------------------------------------------------
+Copyright (C) 2024 zhenwei pi
+
+This work is licensed under BSD 3-Clause, License 1 of the COPYING file in
+the top-level directory.
+==========================================================================
+"""
+import os
+import subprocess
+import netifaces
+import time
+import argparse
+
+def build_program():
+    """Build the rdma-test client with make. Returns 0 on success, 1 on failure."""
+    valkeydir = os.path.dirname(os.path.abspath(__file__)) + "/../.."
+    cmd = "make -C " + valkeydir + "/tests/rdma"
+    p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    # communicate() drains the pipe while waiting; a bare wait() on a piped
+    # child can block forever once the pipe buffer is full
+    outs, _ = p.communicate()
+    if p.returncode:
+        print("Valkey Over RDMA build rdma-test [FAILED]")
+        print("---------------\n" + outs.decode() + "---------------\n")
+        return 1
+
+    print("Valkey Over RDMA build rdma-test program [OK]")
+    return 0
+
+
+# iterate /sys/class/infiniband, find any usable RDMA device, and return IPv4 address
+def find_rdma_dev():
+    """Return the IPv4 address of the first usable RDMA device, or None."""
+    # Ex, /sys/class/infiniband/mlx5_0
+    # Ex, /sys/class/infiniband/rxe_eth0
+    # Ex, /sys/class/infiniband/siw_eth0
+    ibclass = "/sys/class/infiniband/"
+    try:
+        devs = os.listdir(ibclass)
+    except OSError:
+        return None
+
+    for dev in devs:
+        # Ex, /sys/class/infiniband/rxe_eth0/ports/1/gid_attrs/ndevs/0
+        netdev = ibclass + dev + "/ports/1/gid_attrs/ndevs/0"
+        try:
+            with open(netdev) as fp:
+                ifname = fp.readline().strip("\n")
+            addrs = netifaces.ifaddresses(ifname)
+        except (OSError, ValueError):
+            # this device has no readable/usable netdev: skip it and keep
+            # looking instead of giving up on all remaining devices
+            continue
+
+        if netifaces.AF_INET not in addrs:
+            continue
+
+        ipaddr = addrs[netifaces.AF_INET][0]["addr"]
+        print("Valkey Over RDMA test prepare " + dev + " <" + ipaddr + "> [OK]")
+        return ipaddr
+
+    return None
+
+
+def test_rdma(ipaddr):
+    valkeydir = os.path.dirname(os.path.abspath(__file__)) + "/../.."
+    retval = 0
+
+    # step 1, prepare test directory
+    tmpdir = valkeydir + "/tests/rdma/tmp"
+    subprocess.Popen("mkdir -p " + tmpdir, shell=True).wait()
+
+    # step 2, start server
+    svrpath = valkeydir + "/src/valkey-server"
+    rdmapath = valkeydir + "/src/valkey-rdma.so"
+    svrcmd = [svrpath, "--port", "0", "--loglevel", "verbose", "--protected-mode", "no",
+              "--appendonly", "no", "--daemonize", "no", "--dir", valkeydir + "/tests/rdma/tmp",
+              "--loadmodule", rdmapath, "port=6379", "bind=" + ipaddr]
+
+    # the server log is never read back, so do not attach a pipe to it:
+    # an undrained pipe blocks the server once its buffer is full
+    svr = subprocess.Popen(svrcmd, shell=False, stdout=subprocess.DEVNULL)
+    try:
+        # wait(1) returns the exit status if the server died within 1s,
+        # and raises TimeoutExpired if it is still running (the good case)
+        if svr.wait(1):
+            print("Valkey Over RDMA valkey-server runs less than 1s [FAILED]")
+            return 1
+    except subprocess.TimeoutExpired as e:
+        print("Valkey Over RDMA valkey-server start [OK]")
+        pass
+
+    # step 3, run test client
+    start = time.time()
+    clipath = valkeydir + "/tests/rdma/rdma-test"
+    clicmd = [clipath, "--thread", "4", "-h", ipaddr]
+    cli = subprocess.Popen(clicmd, shell=False, stdout=subprocess.PIPE)
+    try:
+        # communicate() drains stdout while waiting, so the client can never
+        # block on a full pipe
+        outs, _ = cli.communicate(timeout=60)
+    except subprocess.TimeoutExpired:
+        # a hung client is a test failure; kill it so that the cleanup below
+        # still runs and the server is not left behind
+        cli.kill()
+        outs, _ = cli.communicate()
+
+    if cli.returncode:
+        print("Valkey Over RDMA test [FAILED]")
+        print("---------------\n" + outs.decode() + "---------------\n")
+        retval = 1
+    else:
+        elapsed = time.time() - start
+        print("Valkey Over RDMA test in " + str(round(elapsed, 2)) + "s [OK]")
+        print(outs.decode())
+        retval = 0
+
+    # step 4, cleanup
+    svr.kill()
+    svr.wait()
+    subprocess.Popen("rm -rf " + tmpdir, shell=True).wait()
+
+    # step 5, report result
+    return retval
+
+
+def test_exit(retval, install_rxe):
+    """Undo the RXE setup (root only, when it was requested) and exit with `retval`."""
+    import sys
+
+    if install_rxe and not os.geteuid():
+        rdma_env_py = os.path.dirname(os.path.abspath(__file__)) + "/rdma_env.py"
+        cmd = rdma_env_py + " -o cleanup"
+        subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE).wait()
+
+    # os._exit() skips the normal interpreter shutdown and does not flush
+    # stdio buffers; flush by hand or the report is lost when stdout is a pipe
+    sys.stdout.flush()
+    os._exit(retval)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description = "Script to test Valkey Over RDMA",
+        formatter_class=argparse.RawDescriptionHelpFormatter)
+    parser.add_argument("-r",
"--install-rxe", action='store_true',
+                        help="install RXE driver and setup RXE device")
+    args = parser.parse_args()
+
+    if args.install_rxe:
+        if os.geteuid():
+            print("--install-rxe/-r must be root privileged")
+            test_exit(1, False)
+
+        rdma_env_py = os.path.dirname(os.path.abspath(__file__)) + "/rdma_env.py"
+        cmd = rdma_env_py + " -o setup -d rxe"
+        p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
+        # drain the pipe while waiting so the child can never block on it
+        p.communicate()
+        if p.returncode:
+            print("Valkey Over RDMA setup RXE [FAILED]")
+            test_exit(1, False)
+
+    # build C client into binary
+    retval = build_program()
+    if retval:
+        test_exit(1, args.install_rxe)
+
+    ipaddr = find_rdma_dev()
+    if ipaddr is None:
+        # not fatal error, continue to create software version: RXE and SIW
+        print("Valkey Over RDMA test detect existing RDMA device [FAILED]")
+    else:
+        retval = test_rdma(ipaddr)
+        if not retval:
+            print("Valkey Over RDMA test over " + ipaddr + " [OK]")
+
+    # propagate the test result: exiting with 0 unconditionally would report
+    # a failed run as a success to CI
+    test_exit(retval, args.install_rxe)