From fc88c3337adf9811548ae641329cc6f31e8dd49d Mon Sep 17 00:00:00 2001 From: Liu Dongmiao Date: Sat, 6 Apr 2024 06:53:53 +0800 Subject: [PATCH] common: fix and optimize option --- common.c | 36 +++++++++++++++++++++++ common.h | 5 +++- wss-proxy-client.c | 71 ++++++++++++++++++---------------------------- wss-proxy-server.c | 35 +++++++++-------------- 4 files changed, 81 insertions(+), 66 deletions(-) diff --git a/common.c b/common.c index 381af9f..0723f2c 100644 --- a/common.c +++ b/common.c @@ -145,6 +145,42 @@ enum log_level get_log_level() { return log_level; } +const char *find_option(const char *options, const char *key, const char *no_value) { + size_t len; + const char *pos, *value; + if (options == NULL) { + return NULL; + } + pos = options; + len = strlen(key); + while ((pos = strstr(pos, key)) != NULL) { + if (pos == options || *(pos - 1) == ';') { + value = pos + len; + if (*value == '=') { + return value + 1; + } else if (*value == ';' || *value == '\0') { + return no_value; + } + } + pos += len; + } + return NULL; +} + +int find_udp_port(int default_port) { + char *end; + const char *value; + value = find_option(getenv("SS_PLUGIN_OPTIONS"), "udp-port", NULL); + if (value != NULL) { + int port = (int) strtol(value, &end, 10); + if (port <= 0 || port > 65535 || (*end != '\0' && *end != ';')) { + port = -1; + } + return port; + } + return default_port; +} + static void on_native_signal(int signal) { if (signal == SIGINT) { LOGW("received interrupt, will exit"); diff --git a/common.h b/common.h index cf427d9..66b9c20 100644 --- a/common.h +++ b/common.h @@ -115,12 +115,15 @@ enum log_level get_log_level(void); } \ } while (0) - #define LOGD(format, ...) LOG(format, stdout, DEBUG, ## __VA_ARGS__) #define LOGI(format, ...) LOG(format, stdout, INFO, ## __VA_ARGS__) #define LOGW(format, ...) LOG(format, stderr, WARN, ## __VA_ARGS__) #define LOGE(format, ...) LOG(format, stderr, ERROR, ## __VA_ARGS__) +const char *find_option(const char *options, const char *key, const char *no_value); + +int find_udp_port(int default_port); + int init_event_signal(struct event_base *base, struct event **event_parent, struct event **event_sigquit); int is_websocket_key(const char *websocket_key); diff --git a/wss-proxy-client.c b/wss-proxy-client.c index c6a8487..e8bacfe 100644 --- a/wss-proxy-client.c +++ b/wss-proxy-client.c @@ -54,7 +54,7 @@ static int bufferevent_udp_cmp(const bufferevent_udp *a, const bufferevent_udp * static int init_raw_addr(struct sockaddr_storage *sockaddr, int *socklen, int *udp_port) { int port; - char *end, *options; + char *end; const char *local_host = getenv("SS_LOCAL_HOST"); const char *local_port = getenv("SS_LOCAL_PORT"); @@ -76,17 +76,7 @@ static int init_raw_addr(struct sockaddr_storage *sockaddr, int *socklen, int *u set_port(sockaddr, port); - *udp_port = port; - if ((options = getenv("SS_PLUGIN_OPTIONS")) != NULL && (options = strstr(options, "udp-port=")) != NULL) { - options += 9; - if ((end = strstr(options, ";")) != NULL) { - *end = '\0'; - } - *udp_port = (int) strtol(options, &end, 10); - if (*udp_port <= 0 || *udp_port > 65535 || *end != '\0') { - *udp_port = -1; - } - } + *udp_port = find_udp_port(port); if (*udp_port > 0) { LOGI("raw server tcp://%s:%d, udp://%s:%d", local_host, port, local_host, *udp_port); @@ -97,11 +87,9 @@ static int init_raw_addr(struct sockaddr_storage *sockaddr, int *socklen, int *u } static int init_wss_addr(struct wss_server_info *server) { - int port; - char *end; - int mux = 1; - char *wss; - const char *loglevel; + int port, mux; + char *end, *wss; + const char *value; const char *remote_host = getenv("SS_REMOTE_HOST"); const char *remote_port = getenv("SS_REMOTE_PORT"); const char *options = getenv("SS_PLUGIN_OPTIONS"); @@ -124,60 +112,51 @@ static int init_wss_addr(struct wss_server_info *server) { } server->port = port; - if (options == NULL) { - options = ""; - } - if (strchr(options, '\\') != NULL) { + if (options != NULL && strchr(options, '\\') != NULL) { LOGE("plugin options %s (contains \\) is unsupported", options); return EINVAL; } // host - server->host = strstr(options, "host="); + server->host = find_option(options, "host", NULL); if (server->host == NULL) { server->host = remote_host; - } else { - server->host += 5; } // path - server->path = strstr(options, "path="); + server->path = find_option(options, "path", NULL); if (server->path == NULL) { server->path = "/"; - } else { - server->path += 5; } // tls - if ((end = strstr(options, "tls")) != NULL) { - end += 3; - if (*end == '\0' || *end == ';') { - server->tls = 1; - } + if ((value = find_option(options, "tls", "1")) != NULL) { + server->tls = (int) strtol(value, NULL, 10); } // loglevel - if ((loglevel = strstr(options, "loglevel=")) != NULL) { - loglevel += 9; - init_log_level(loglevel); + if ((value = find_option(options, "loglevel", NULL)) != NULL) { + init_log_level(value); } // mux - if ((end = strstr(options, "mux=")) != NULL) { - end += 4; - mux = (int) strtol(end, NULL, 10); + if ((value = find_option(options, "mux", "1")) != NULL) { + mux = (int) strtol(value, NULL, 10); + } else { + mux = 1; } // wss - if ((end = strstr(options, "ws=")) != NULL) { - end += 3; - server->ws = (int) strtol(end, NULL, 10); + if ((value = find_option(options, "ws", "1")) != NULL) { + server->ws = (int) strtol(value, NULL, 10); } else { server->ws = 1; } // strip - if ((end = strstr(server->host, ";")) != NULL) { + server->host = strdup(server->host); + if ((end = strchr(server->host, ';')) != NULL) { *end = '\0'; } - if ((end = strstr(server->path, ";")) != NULL) { + server->path = strdup(server->path); + if ((end = strchr(server->path, ';')) != NULL) { *end = '\0'; } @@ -649,5 +628,11 @@ int main() { if (wss_context.ssl_ctx) { SSL_CTX_free(wss_context.ssl_ctx); } + if (wss_context.server.host) { + free((char *) wss_context.server.host); + } + if (wss_context.server.path) { + free((char *) wss_context.server.path); + } return code; } diff --git a/wss-proxy-server.c b/wss-proxy-server.c index 9c44594..bf15b85 100644 --- a/wss-proxy-server.c +++ b/wss-proxy-server.c @@ -20,12 +20,12 @@ struct raw_server_info { }; static int init_ws_info(const char **addr, int *port) { - int mux = 1; + int mux; char *end; + const char *value; const char *remote_host = getenv("SS_REMOTE_HOST"); const char *remote_port = getenv("SS_REMOTE_PORT"); const char *options = getenv("SS_PLUGIN_OPTIONS"); - const char *loglevel; if (remote_host != NULL && strchr(remote_host, '|') != NULL) { LOGE("remote host %s is not supported", remote_host); return EINVAL; @@ -36,18 +36,19 @@ static int init_ws_info(const char **addr, int *port) { LOGE("remote port %s is not supported", remote_port); return EINVAL; } - if (options == NULL) { - options = ""; + if (options != NULL && strchr(options, '\\') != NULL) { + LOGE("plugin options %s (contains \\) is unsupported", options); + return EINVAL; } // mux - if ((end = strstr(options, "mux=")) != NULL) { - end += 4; - mux = (int) strtol(end, NULL, 10); + if ((value = find_option(options, "mux", NULL)) != NULL) { + mux = (int) strtol(value, NULL, 10); + } else { + mux = 1; } // loglevel - if ((loglevel = strstr(options, "loglevel=")) != NULL) { - loglevel += 9; - init_log_level(loglevel); + if ((value = find_option(options, "loglevel", NULL)) != NULL) { + init_log_level(value); } LOGI("wss server %s:%d", *addr, *port); if (mux) { @@ -57,7 +58,7 @@ static int init_ws_info(const char **addr, int *port) { } static int init_raw_info(struct raw_server_info *raw_server_info) { - char *end, *options; + char *end; const char *local_host = getenv("SS_LOCAL_HOST"); const char *local_port = getenv("SS_LOCAL_PORT"); raw_server_info->addr = local_host == NULL ? "127.0.0.1" : local_host; @@ -71,17 +72,7 @@ static int init_raw_info(struct raw_server_info *raw_server_info) { return EINVAL; } - raw_server_info->udp_port = raw_server_info->port; - if ((options = getenv("SS_PLUGIN_OPTIONS")) != NULL && (options = strstr(options, "udp-port=")) != NULL) { - options += 9; - if ((end = strstr(options, ";")) != NULL) { - *end = '\0'; - } - raw_server_info->udp_port = (int) strtol(options, &end, 10); - if (raw_server_info->udp_port <= 0 || raw_server_info->udp_port > 65535 || *end != '\0') { - raw_server_info->udp_port = -1; - } - } + raw_server_info->udp_port = find_udp_port(raw_server_info->port); if (raw_server_info->udp_port > 0) { LOGI("raw client tcp://%s:%d, udp://%s:%d", raw_server_info->addr, raw_server_info->port,