diff --git a/common/include/tfe_cmsg.h b/common/include/tfe_cmsg.h index fcf09a5..31d7246 100644 --- a/common/include/tfe_cmsg.h +++ b/common/include/tfe_cmsg.h @@ -9,28 +9,34 @@ enum tfe_cmsg_errno { TFE_CMSG_INVALID_FORMAT = -1, TFE_CMSG_BUFF_NOT_ENOUGH = -2, - TFE_CMSG_INVALID_TYPE = -3 + TFE_CMSG_INVALID_TYPE = -3, }; enum tfe_cmsg_tlv_type { - TCP_RESTORE_INFO_TLV_SEQ = 0, - TCP_RESTORE_INFO_TLV_ACK, - TCP_RESTORE_INFO_TLV_MSS_CLIENT, - TCP_RESTORE_INFO_TLV_MSS_SERVER, - TCP_RESTORE_INFO_TLV_WSACLE_CLIENT, - TCP_RESTORE_INFO_TLV_WSACLE_SERVER, - TCP_RESTORE_INFO_TLV_SACK_CLIENT, - TCP_RESTORE_INFO_TLV_SACK_SERVER, - TCP_RESTORE_INFO_TLV_TS_CLIENT, - TCP_RESTORE_INFO_TLV_TS_SERVER, - TCP_RESTORE_INFO_TLV_USER_DEFINED + /* TCP restore information */ + TFE_CMSG_TCP_RESTORE_SEQ = 0x0, + TFE_CMSG_TCP_RESTORE_ACK = 0x1, + TFE_CMSG_TCP_RESTORE_MSS_CLIENT = 0x2, + TFE_CMSG_TCP_RESTORE_MSS_SERVER = 0x3, + TFE_CMSG_TCP_RESTORE_WSACLE_CLIENT = 0x4, + TFE_CMSG_TCP_RESTORE_WSACLE_SERVER = 0x5, + TFE_CMSG_TCP_RESTORE_SACK_CLIENT = 0x6, + TFE_CMSG_TCP_RESTORE_SACK_SERVER = 0x7, + TFE_CMSG_TCP_RESTORE_TS_CLIENT = 0x8, + TFE_CMSG_TCP_RESTORE_TS_SERVER = 0x9, + TFE_CMSG_TCP_RESTORE_PROTOCOL = 0xa, + + TFE_CMSG_POLICY_ID = 0x10, + TFE_CMSG_STREAM_TRACE_ID = 0x11, }; struct tfe_cmsg* tfe_cmsg_init(); void tfe_cmsg_destroy(struct tfe_cmsg *cmsg); -int tfe_cmsg_get(struct tfe_cmsg *cmsg, uint16_t type, uint16_t *size, unsigned char **pvalue); -int tfe_cmsg_set(struct tfe_cmsg *cmsg, uint16_t type, const unsigned char *value, uint16_t size); + +int tfe_cmsg_get_value(struct tfe_cmsg * cmsg, enum tfe_cmsg_tlv_type type, char * out_value, + size_t sz_out_value_buf, uint16_t * out_size); +int tfe_cmsg_set(struct tfe_cmsg * cmsg, enum tfe_cmsg_tlv_type type, const unsigned char * value, uint16_t size); uint16_t tfe_cmsg_serialize_size_get(struct tfe_cmsg *cmsg); int tfe_cmsg_serialize(struct tfe_cmsg *cmsg, unsigned char *buff, uint16_t bufflen, uint16_t *serialize_len); int tfe_cmsg_deserialize(const unsigned char *data, uint16_t len, struct tfe_cmsg** pcmsg); diff --git a/common/include/tfe_stream.h b/common/include/tfe_stream.h index 4b62791..97329ad 100644 --- a/common/include/tfe_stream.h +++ b/common/include/tfe_stream.h @@ -4,6 +4,7 @@ #include #include #include +#include enum tfe_stream_proto { @@ -98,6 +99,10 @@ int tfe_stream_shutdown(const struct tfe_stream * stream); int tfe_stream_shutdown_dir(const struct tfe_stream * stream, enum tfe_conn_dir dir); void tfe_stream_kill(const struct tfe_stream * stream); +/* stream's cmsg */ +struct tfe_cmsg * tfe_stream_get0_cmsg(const struct tfe_stream * stream); +void tfe_stream_cmsg_setup(const struct tfe_stream * stream, struct tfe_cmsg * cmsg); + /** * @brief Write linear text for given stream */ diff --git a/common/src/tfe_cmsg.cpp b/common/src/tfe_cmsg.cpp index fcd3683..3a3881a 100644 --- a/common/src/tfe_cmsg.cpp +++ b/common/src/tfe_cmsg.cpp @@ -2,6 +2,8 @@ #include #include #include +#include + #include "tfe_types.h" #include "tfe_utils.h" #include "tfe_cmsg.h" @@ -52,7 +54,7 @@ void tfe_cmsg_destroy(struct tfe_cmsg *cmsg) FREE(&cmsg); } -int tfe_cmsg_set(struct tfe_cmsg *cmsg, uint16_t type, const unsigned char *value, uint16_t size) +int tfe_cmsg_set(struct tfe_cmsg * cmsg, enum tfe_cmsg_tlv_type type, const unsigned char * value, uint16_t size) { if(type >= TFE_CMSG_TLV_NR_MAX) { @@ -73,17 +75,40 @@ int tfe_cmsg_set(struct tfe_cmsg *cmsg, uint16_t type, const unsigned char *valu return 0; } -int tfe_cmsg_get(struct tfe_cmsg *cmsg, uint16_t type, uint16_t *size, unsigned char **pvalue) +int tfe_cmsg_get_value(struct tfe_cmsg * cmsg, enum tfe_cmsg_tlv_type type, char * out_value, + size_t sz_out_value_buf, uint16_t * out_size) { - struct tfe_cmsg_tlv *tlv = NULL; - if(type >= TFE_CMSG_TLV_NR_MAX || (tlv = cmsg->tlvs[type]) == NULL) - { - *size = 0; - return TFE_CMSG_INVALID_TYPE; - } - *size = tlv->length - sizeof(struct tfe_cmsg_tlv); - *pvalue = tlv->value_as_string; - return 0; + struct tfe_cmsg_tlv *tlv; + int result = 0; + int value_length = 0; + + if (unlikely(type >= TFE_CMSG_TLV_NR_MAX)) + { + result = -EINVAL; + goto errout; + } + + tlv = cmsg->tlvs[type]; + if (unlikely(tlv == NULL)) + { + result = -ENOENT; + goto errout; + } + + value_length = tlv->length - sizeof(struct tfe_cmsg_tlv); + if (unlikely(sz_out_value_buf < value_length)) + { + result = -ENOBUFS; + goto errout; + } + + memcpy(out_value, tlv->value_as_string, value_length); + *out_size = value_length; + + return 0; + +errout: + return result; } uint16_t tfe_cmsg_serialize_size_get(struct tfe_cmsg *cmsg) @@ -160,6 +185,7 @@ int tfe_cmsg_deserialize(const unsigned char *data, uint16_t len, struct tfe_cms struct tfe_cmsg_serialize_header *header = (struct tfe_cmsg_serialize_header*)data; struct tfe_cmsg *cmsg = NULL; int offset = 0, nr_tlvs = -1; + if(len < sizeof(struct tfe_cmsg_serialize_header)) { goto error_out; @@ -168,6 +194,7 @@ int tfe_cmsg_deserialize(const unsigned char *data, uint16_t len, struct tfe_cms { goto error_out; } + cmsg = ALLOC(struct tfe_cmsg, 1); offset = sizeof(struct tfe_cmsg_serialize_header); nr_tlvs = ntohs(header->nr_tlvs); @@ -178,13 +205,18 @@ int tfe_cmsg_deserialize(const unsigned char *data, uint16_t len, struct tfe_cms { goto error_out; } + uint16_t type = ntohs(tlv->type); uint16_t length = ntohs(tlv->length); + if(length < sizeof(struct tfe_cmsg_tlv) || offset + length > len) { goto error_out; } - int ret = tfe_cmsg_set(cmsg, type, data + offset + sizeof(struct tfe_cmsg_tlv), length - sizeof(struct tfe_cmsg_tlv)); + + int ret = tfe_cmsg_set(cmsg, (enum tfe_cmsg_tlv_type)type, + data + offset + sizeof(struct tfe_cmsg_tlv), length - sizeof(struct tfe_cmsg_tlv)); + if(ret < 0) { goto error_out; diff --git a/common/test/test_cmsg.cpp b/common/test/test_cmsg.cpp index 2d02d58..a5a859f 100644 --- a/common/test/test_cmsg.cpp +++ b/common/test/test_cmsg.cpp @@ -8,19 +8,25 @@ #include "tfe_utils.h" #include "tfe_cmsg.h" +int main() +{ + return 0; +} + +/* int main(){ //init struct tfe_cmsg *cmsg = tfe_cmsg_init(); //set uint32_t value = 0x12345678; - int ret = tfe_cmsg_set(cmsg, TCP_RESTORE_INFO_TLV_SEQ, (const unsigned char*)(&value), 4); + int ret = tfe_cmsg_set(cmsg, TFE_CMSG_TCP_RESTORE_SEQ, (const unsigned char*)(&value), 4); printf("tfe_cmsg_set: ret is %d\n", ret); //get TCP_RESTORE_INFO_TLV_SEQ uint16_t size = -1; unsigned char *value1 = NULL; - ret = tfe_cmsg_get(cmsg, TCP_RESTORE_INFO_TLV_SEQ, &size, &value1); + ret = tfe_cmsg_get(cmsg, TFE_CMSG_TCP_RESTORE_SEQ, &size, &value1); printf("tfe_cmsg_get: ret is %d, type is TCP_RESTORE_INFO_TLV_SEQ, value is 0x%02x, value_size is %d\n", ret, ((uint32_t*)value1)[0], size); //get_serialize_size @@ -45,6 +51,7 @@ int main(){ //get TCP_RESTORE_INFO_TLV_SEQ size = -1; unsigned char *value2 = NULL; - ret = tfe_cmsg_get(cmsg1, TCP_RESTORE_INFO_TLV_SEQ, &size, &value2); + ret = tfe_cmsg_get(cmsg1, TFE_CMSG_TCP_RESTORE_SEQ, &size, &value2); printf("tfe_cmsg_get: ret is %d, type is TCP_RESTORE_INFO_TLV_SEQ, value is 0x%02x, value_size is %d\n", ret, ((uint32_t*)value2)[0], size); } +*/ diff --git a/platform/CMakeLists.txt b/platform/CMakeLists.txt index 68cdc64..b6a1d2c 100644 --- a/platform/CMakeLists.txt +++ b/platform/CMakeLists.txt @@ -1,7 +1,7 @@ -add_executable(tfe src/acceptor_scm.cpp src/ssl_stream.cpp - src/ssl_sess_cache.cpp src/ssl_sess_ticket.cpp src/ssl_service_cache.cpp - src/ssl_trusted_cert_storage.cpp src/ev_root_ca_metadata.cpp src/ssl_utils.cpp - src/tcp_stream.cpp src/main.cpp src/proxy.cpp) +add_executable(tfe src/acceptor_scm.cpp src/ssl_stream.cpp src/key_keeper.cpp + src/ssl_sess_cache.cpp src/ssl_sess_ticket.cpp src/ssl_service_cache.cpp + src/ssl_trusted_cert_storage.cpp src/ev_root_ca_metadata.cpp src/ssl_utils.cpp + src/tcp_stream.cpp src/main.cpp src/proxy.cpp) target_include_directories(tfe PUBLIC ${CMAKE_CURRENT_LIST_DIR}/include/external) target_include_directories(tfe PRIVATE ${CMAKE_CURRENT_LIST_DIR}/include/internal) diff --git a/platform/include/internal/platform.h b/platform/include/internal/platform.h index adb41b8..d85c91e 100644 --- a/platform/include/internal/platform.h +++ b/platform/include/internal/platform.h @@ -75,6 +75,7 @@ struct tfe_stream_private struct tfe_proxy * proxy_ref; struct tfe_thread_ctx * thread_ref; + struct tfe_cmsg * cmsg; enum tfe_stream_proto session_type; struct tfe_stream_write_ctx * w_ctx_upstream; diff --git a/platform/include/internal/proxy.h b/platform/include/internal/proxy.h index c69e4cd..4b06ddc 100644 --- a/platform/include/internal/proxy.h +++ b/platform/include/internal/proxy.h @@ -117,5 +117,5 @@ struct tfe_thread_ctx * tfe_proxy_thread_ctx_acquire(struct tfe_proxy * ctx); void tfe_proxy_thread_ctx_release(struct tfe_thread_ctx * thread_ctx); struct tfe_proxy * tfe_proxy_new(const char * profile); -int tfe_proxy_fds_accept(struct tfe_proxy * ctx, const struct tfe_proxy_accept_para * para); +int tfe_proxy_fds_accept(struct tfe_proxy * ctx, int fd_downstream, int fd_upstream, struct tfe_cmsg * cmsg); void tfe_proxy_run(struct tfe_proxy * proxy); diff --git a/platform/src/acceptor_scm.cpp b/platform/src/acceptor_scm.cpp index 7ff8f75..0a3cb5f 100644 --- a/platform/src/acceptor_scm.cpp +++ b/platform/src/acceptor_scm.cpp @@ -13,41 +13,16 @@ #include #include +#include #include #include #include +# #ifndef TFE_CONFIG_SCM_SOCKET_FILE #define TFE_CONFIG_SCM_SOCKET_FILE "/var/run/.tfe_kmod_scm_socket" #endif -/* The KNI and TFE communicate with each other by UNIX-based socket, - * and the protocol between them is based on TLV format(Type-Length-Value). - * The byte order for each entry in the protocol are Host-Ordered. - * - * I. Magic and Total counts of T-L-V tuples, at front of the SOCKET stream. - * II. After Magic header, the stream consist of several T-L-V tuples. - * - * Note. Magic = 0x4d5a - * Consider of the byte align problem, the minimum length of the value is 4bytes(32-bits). - * - * 0 1 2 3 - * 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Magic | Total counts of TLV | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Type | Length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Value | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Type | Length | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | Value | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - * | ....... | - * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - struct acceptor_scm { /* INPUT */ @@ -65,209 +40,12 @@ struct acceptor_scm pthread_t thread; }; -enum tcp_restore_info_tlv_type -{ - TCP_RESTORE_INFO_TLV_SEQ, - TCP_RESTORE_INFO_TLV_ACK, - TCP_RESTORE_INFO_TLV_MSS_CLIENT, - TCP_RESTORE_INFO_TLV_MSS_SERVER, - TCP_RESTORE_INFO_TLV_WSACLE_CLIENT, - TCP_RESTORE_INFO_TLV_WSACLE_SERVER, - TCP_RESTORE_INFO_TLV_SACK_CLIENT, - TCP_RESTORE_INFO_TLV_SACK_SERVER, - TCP_RESTORE_INFO_TLV_TS_CLIENT, - TCP_RESTORE_INFO_TLV_TS_SERVER, - TCP_RESTORE_INFO_TLV_PROTOCOL, - TCP_RESTORE_INFO_TLV_USER_DEFINED -}; - -struct tcp_restore_info_endpoint -{ - struct sockaddr_storage addr; - uint32_t seq; - uint32_t ack; - bool wscale_perm; - bool timestamp_perm; - bool sack_perm; - uint16_t mss; - uint8_t wscale; -}; - -struct tcp_restore_info -{ - struct tcp_restore_info_endpoint client; - struct tcp_restore_info_endpoint server; - unsigned int protocol; -}; - -struct tcp_restore_info_tlv -{ - uint16_t type; - uint16_t length; - - union - { - uint8_t value_as_uint8[0]; - uint16_t value_as_uint16[0]; - uint32_t value_as_uint32[0]; - unsigned char value_as_string[0]; - }; - -} __attribute__((packed)); - -struct tcp_restore_info_header -{ - uint8_t __magic__[2]; /* Must be 0x4d, 0x5a */ - uint16_t nr_tlvs; - struct tcp_restore_info_tlv tlvs[0]; -} __attribute__((packed)); - -#define TCP_RESTORE_TCPOPT_KIND 88 -#define TCP_RESTORE_TCPOPT_LENGTH 4 - -/* Copy from tfe-kmod */ -int tcp_restore_info_parse_from_cmsg(struct acceptor_scm * ctx, - const char * data, unsigned int datalen, struct tcp_restore_info * out) -{ - struct tcp_restore_info_header * header = (struct tcp_restore_info_header *)data; - unsigned int tlv_iter; - unsigned int nr_tlvs; - - if(unlikely(header->__magic__[0] != 0x4d || header->__magic__[1] != 0x5a)) - { - TFE_LOG_ERROR(ctx->logger, "Invalid restore info format: wrong magic, drop it.\n"); - goto invalid_format; - } - - nr_tlvs = ntohs(header->nr_tlvs); - if (unlikely(nr_tlvs >= 256)) - { - TFE_LOG_ERROR(ctx->logger, "Invalid restore info format: numbers of tlvs is larger than 256, drop it.\n"); - goto invalid_format; - } - - if (unlikely(datalen < sizeof(struct tcp_restore_info_header))) - { - TFE_LOG_ERROR(ctx->logger, "Invalid restore info format: length is shorter than tlv header, drop it.\n"); - goto invalid_format; - } - - datalen -= sizeof(struct tcp_restore_info_header); - data += sizeof(struct tcp_restore_info_header); - - for(tlv_iter = 0; tlv_iter < nr_tlvs; tlv_iter++) - { - struct tcp_restore_info_tlv * tlv = (struct tcp_restore_info_tlv *)data; - uint16_t tlv_type = ntohs(tlv->type); - uint16_t tlv_length = ntohs(tlv->length); - - unsigned int __length = tlv_length; - if(unlikely(datalen < __length)) - { - TFE_LOG_ERROR(ctx->logger, "Invalid restore info format: left space is smaller than tlv's length, " - "datalen is %u, tlv's length is %u, drop it.", datalen, __length); - goto invalid_format; - } - - if(unlikely(tlv_length < sizeof(uint16_t) * 2)) - { - TFE_LOG_ERROR(ctx->logger, "Invalid restore info format: invalid tlv length, " - "should larger than sizeof(type) + sizeof(length)."); - goto invalid_format; - } - - tlv_length -= sizeof(uint16_t) * 2; - -#define __CHECK_TLV_LENGTH(x) do { if(unlikely(x != tlv_length)) { \ - TFE_LOG_ERROR(ctx->logger, "Invalid restore format: invalid tlv length, should be %u, actually is %u, drop it.", \ - (unsigned int)x, (unsigned int)tlv_length); goto invalid_format; }} while(0) - - switch(tlv_type) - { - case TCP_RESTORE_INFO_TLV_SEQ: - __CHECK_TLV_LENGTH(sizeof(uint32_t)); - out->client.seq = ntohl(tlv->value_as_uint32[0]); - out->server.ack = ntohl(tlv->value_as_uint32[0]); - break; - - case TCP_RESTORE_INFO_TLV_ACK: - __CHECK_TLV_LENGTH(sizeof(uint32_t)); - out->client.ack = ntohl(tlv->value_as_uint32[0]); - out->server.seq = ntohl(tlv->value_as_uint32[0]); - break; - - case TCP_RESTORE_INFO_TLV_TS_CLIENT: - __CHECK_TLV_LENGTH(sizeof(uint8_t)); - out->client.timestamp_perm = !!(tlv->value_as_uint8[0]); - break; - - case TCP_RESTORE_INFO_TLV_TS_SERVER: - __CHECK_TLV_LENGTH(sizeof(uint8_t)); - out->server.timestamp_perm = !!(tlv->value_as_uint8[0]); - break; - - case TCP_RESTORE_INFO_TLV_WSACLE_CLIENT: - __CHECK_TLV_LENGTH(sizeof(uint8_t)); - out->client.wscale_perm = true; - out->client.wscale = tlv->value_as_uint8[0]; - break; - - case TCP_RESTORE_INFO_TLV_WSACLE_SERVER: - __CHECK_TLV_LENGTH(sizeof(uint8_t)); - out->server.wscale_perm = true; - out->server.wscale = tlv->value_as_uint8[0]; - break; - - case TCP_RESTORE_INFO_TLV_SACK_CLIENT: - __CHECK_TLV_LENGTH(sizeof(uint8_t)); - out->client.sack_perm = true; - break; - - case TCP_RESTORE_INFO_TLV_SACK_SERVER: - __CHECK_TLV_LENGTH(sizeof(uint8_t)); - out->server.sack_perm = true; - break; - - case TCP_RESTORE_INFO_TLV_MSS_CLIENT: - __CHECK_TLV_LENGTH(sizeof(uint16_t)); - out->client.mss = ntohs(tlv->value_as_uint16[0]); - break; - - case TCP_RESTORE_INFO_TLV_MSS_SERVER: - __CHECK_TLV_LENGTH(sizeof(uint16_t)); - out->server.mss = ntohs(tlv->value_as_uint16[0]); - break; - - case TCP_RESTORE_INFO_TLV_PROTOCOL: - __CHECK_TLV_LENGTH(sizeof(uint8_t)); - out->protocol = tlv->value_as_uint8[0]; - break; - - case TCP_RESTORE_INFO_TLV_USER_DEFINED: - break; - - default: - TFE_LOG_ERROR(ctx->logger, "Invalid restore info format: unsupported type %x, drop it.\n", tlv_type); - goto invalid_format; - } - - data += __length; - datalen -= __length; - } - - return 0; - -invalid_format: - return -EINVAL; -} - - void acceptor_scm_event(evutil_socket_t fd, short what, void * user) { struct acceptor_scm * __ctx = (struct acceptor_scm *) user; struct cmsghdr * __cmsghdr; struct tfe_proxy_accept_para __accept_para{}; - struct tcp_restore_info restore_info{}; + struct tfe_cmsg * cmsg = NULL; int * __fds = NULL; assert(__ctx != NULL && __ctx->thread == pthread_self()); @@ -322,18 +100,16 @@ void acceptor_scm_event(evutil_socket_t fd, short what, void * user) goto __die; } - if (tcp_restore_info_parse_from_cmsg(__ctx, __buffer, (size_t)rd, &restore_info) < 0) - { - TFE_LOG_ERROR(__ctx->logger, "Failed at parsing TLV format, drop the connection."); - goto __drop_recieved_fds; - } + /* Apply a cmsg structure */ + if (tfe_cmsg_deserialize((const unsigned char *)__buffer, (uint16_t)rd, &cmsg) < 0) + { + /* TODO: dump the buffer in hexdump format */ + TFE_LOG_ERROR(__ctx->logger, "failed at cmsg_deserialize(), invalid format."); + goto __drop_recieved_fds; + } - __accept_para.downstream_fd = __fds[0]; - __accept_para.upstream_fd = __fds[1]; - __accept_para.session_type = restore_info.protocol ? STREAM_PROTO_SSL : STREAM_PROTO_PLAIN; - - TFE_PROXY_STAT_INCREASE(STAT_FD_OPEN_BY_KNI_ACCEPT, 2); - if (tfe_proxy_fds_accept(__ctx->proxy, &__accept_para) < 0) + TFE_PROXY_STAT_INCREASE(STAT_FD_OPEN_BY_KNI_ACCEPT, 2); + if (tfe_proxy_fds_accept(__ctx->proxy, __fds[0], __fds[1], cmsg) < 0) { goto __drop_recieved_fds; } diff --git a/platform/src/proxy.cpp b/platform/src/proxy.cpp index cb3030b..69c7ac4 100644 --- a/platform/src/proxy.cpp +++ b/platform/src/proxy.cpp @@ -30,6 +30,8 @@ #include #include #include +#include +#include #include #include @@ -37,7 +39,6 @@ #include #include #include -#include extern struct ssl_policy_enforcer* ssl_policy_enforcer_create(void* logger); extern enum ssl_stream_action ssl_policy_enforce(struct ssl_stream *upstream, void* u_para); @@ -99,16 +100,28 @@ void tfe_proxy_thread_ctx_release(struct tfe_thread_ctx * thread_ctx) ATOMIC_DEC(&thread_ctx->load); } -int tfe_proxy_fds_accept(struct tfe_proxy * ctx, const struct tfe_proxy_accept_para * para) +int tfe_proxy_fds_accept(struct tfe_proxy * ctx, int fd_downstream, int fd_upstream, struct tfe_cmsg * cmsg) { - tfe_thread_ctx * worker_thread_ctx = tfe_proxy_thread_ctx_acquire(ctx); - + struct tfe_thread_ctx * worker_thread_ctx = tfe_proxy_thread_ctx_acquire(ctx); struct tfe_stream * stream = tfe_stream_create(ctx, worker_thread_ctx); - tfe_stream_option_set(stream, TFE_STREAM_OPT_SESSION_TYPE, ¶->session_type, sizeof(para->session_type)); - tfe_stream_option_set(stream, TFE_STREAM_OPT_KEYRING_ID, ¶->keyring_id, sizeof(para->keyring_id)); + + enum tfe_stream_proto stream_protocol; + uint16_t __size; + + int result = tfe_cmsg_get_value(cmsg, TFE_CMSG_TCP_RESTORE_PROTOCOL, (char *)&stream_protocol, + sizeof(stream_protocol), &__size); + + if (unlikely(result < 0)) + { + TFE_LOG_ERROR(ctx->logger, "failed at fetch connection's protocol from cmsg: %s", strerror(-result)); + goto __errout; + } + + tfe_stream_option_set(stream, TFE_STREAM_OPT_SESSION_TYPE, &stream_protocol, sizeof(stream_protocol)); + tfe_stream_cmsg_setup(stream, cmsg); /* FOR DEBUG */ - if (para->passthrough || ctx->tcp_all_passthrough) + if (unlikely(ctx->tcp_all_passthrough)) { bool __true = true; enum tfe_stream_proto __session_type = STREAM_PROTO_PLAIN; @@ -117,17 +130,16 @@ int tfe_proxy_fds_accept(struct tfe_proxy * ctx, const struct tfe_proxy_accept_p tfe_stream_option_set(stream, TFE_STREAM_OPT_SESSION_TYPE, &__session_type, sizeof(__session_type)); } - int ret = tfe_stream_init_by_fds(stream, para->downstream_fd, para->upstream_fd); - if (ret < 0) + result = tfe_stream_init_by_fds(stream, fd_downstream, fd_upstream); + if (result < 0) { TFE_LOG_ERROR(ctx->logger, "%p, Fds(downstream = %d, upstream = %d, type = %d) accept failed.", - stream, para->downstream_fd, para->upstream_fd, para->session_type); - goto __errout; + stream, fd_downstream, fd_upstream, stream_protocol); goto __errout; } else { TFE_LOG_DEBUG(ctx->logger, "%p, Fds(downstream = %d, upstream = %d, type = %d) accepted.", - stream, para->downstream_fd, para->upstream_fd, para->session_type); + stream, fd_downstream, fd_upstream, stream_protocol); } return 0; diff --git a/platform/src/ssl_stream.cpp b/platform/src/ssl_stream.cpp index b4446c2..d5256c3 100644 --- a/platform/src/ssl_stream.cpp +++ b/platform/src/ssl_stream.cpp @@ -820,7 +820,6 @@ static void ssl_async_peek_client_hello(struct future * f, evutil_socket_t fd, s struct peek_client_hello_ctx * ctx = ALLOC(struct peek_client_hello_ctx, 1); ctx->ev = event_new(evbase, fd, EV_READ, peek_client_hello_cb, p); ctx->evbase = evbase; - ctx->parse_client_cipher=parse_cipher; ctx->logger = logger; promise_set_ctx(p, (void *) ctx, peek_client_hello_ctx_free_cb); event_add(ctx->ev, NULL); diff --git a/platform/src/tcp_stream.cpp b/platform/src/tcp_stream.cpp index e47f339..59eaa38 100644 --- a/platform/src/tcp_stream.cpp +++ b/platform/src/tcp_stream.cpp @@ -1300,6 +1300,18 @@ int tfe_stream_option_set(struct tfe_stream * stream, enum tfe_stream_option opt return 0; } +struct tfe_cmsg * tfe_stream_get0_cmsg(const struct tfe_stream * stream) +{ + struct tfe_stream_private * _stream = container_of(stream, struct tfe_stream_private, head); + return _stream->cmsg; +} +void tfe_stream_cmsg_setup(const struct tfe_stream * stream, struct tfe_cmsg * cmsg) +{ + struct tfe_stream_private * _stream = container_of(stream, struct tfe_stream_private, head); + assert(_stream->cmsg == NULL); + _stream->cmsg = cmsg; +} + void tfe_stream_write_access_log(const struct tfe_stream * stream, int level, const char * fmt, ...) { va_list arg_ptr;