diff --git a/common/include/intercept_policy.h b/common/include/intercept_policy.h index 5fab9ff..4d81169 100644 --- a/common/include/intercept_policy.h +++ b/common/include/intercept_policy.h @@ -6,4 +6,7 @@ struct intercept_policy_enforcer *intercept_policy_enforcer_create(void *logger) void intercept_policy_enforce_destory(struct intercept_policy_enforcer *enforcer); // return 0 : success // return -1 : error (need passthrough) +int intercept_policy_select(struct intercept_policy_enforcer *enforcer, uint64_t *rule_id_array, int rule_id_num, uint64_t *selected_rule_id); +// return 0 : success +// return -1 : error (need passthrough) int intercept_policy_enforce(struct intercept_policy_enforcer *enforcer, struct tfe_cmsg *cmsg); \ No newline at end of file diff --git a/common/include/tfe_ctrl_packet.h b/common/include/tfe_ctrl_packet.h index 2c28f42..ab9b6c7 100644 --- a/common/include/tfe_ctrl_packet.h +++ b/common/include/tfe_ctrl_packet.h @@ -16,15 +16,17 @@ enum session_state SESSION_STATE_RESETALL = 4, }; +#define MAX_HIT_RULES 32 + struct ctrl_pkt_parser { char tsync[4]; uint64_t session_id; enum session_state state; char method[32]; - uint64_t tfe_policy_ids[32]; + uint64_t tfe_policy_ids[MAX_HIT_RULES]; int tfe_policy_id_num; - uint64_t sce_policy_ids[32]; + uint64_t sce_policy_ids[MAX_HIT_RULES]; int sce_policy_id_num; struct tfe_cmsg *cmsg; diff --git a/common/src/intercept_policy.cpp b/common/src/intercept_policy.cpp index a90e780..68888e5 100644 --- a/common/src/intercept_policy.cpp +++ b/common/src/intercept_policy.cpp @@ -244,6 +244,62 @@ void intercept_policy_enforce_destory(struct intercept_policy_enforcer *enforcer } } +// return 0 : success +// return -1 : error (need passthrough) +int intercept_policy_select(struct intercept_policy_enforcer *enforcer, uint64_t *rule_id_array, int rule_id_num, uint64_t *selected_rule_id) +{ + uint64_t rule_id = 0; + uint8_t is_hit_intercept_rule = 0; + uint8_t is_hit_no_intercept_rule = 0; + uint64_t max_intercept_rule_id = 0; + uint64_t max_no_intercept_rule_id = 0; + + char buff[16] = {0}; + struct intercept_param *param = NULL; + + for (int i = 0; i < rule_id_num; i++) + { + rule_id = rule_id_array[i]; + snprintf(buff, sizeof(buff), "%lu", rule_id); + param = (struct intercept_param *)maat_plugin_table_get_ex_data(enforcer->maat, enforcer->table_id, buff, strlen(buff)); + if (param == NULL) + { + TFE_LOG_INFO(enforcer->logger, "Failed to get intercept parameter of policy %lu.", rule_id); + continue; + } + + // intercept + if (param->action == 2) + { + is_hit_intercept_rule = 1; + max_intercept_rule_id = MAX(max_intercept_rule_id, rule_id); + TFE_LOG_INFO(enforcer->logger, "rule[%d/%d]: %lu is intercept.", i, rule_id_num, rule_id); + } + // not intercept + else + { + is_hit_no_intercept_rule = 1; + max_no_intercept_rule_id = MAX(max_no_intercept_rule_id, rule_id); + TFE_LOG_INFO(enforcer->logger, "rule[%d/%d]: %lu is no intercept.", i, rule_id_num, rule_id); + } + } + + if (is_hit_no_intercept_rule) + { + *selected_rule_id = max_no_intercept_rule_id; + return 0; + } + + if (is_hit_intercept_rule) + { + *selected_rule_id = max_intercept_rule_id; + return 0; + } + + // no policy get, passthrough + return -1; +} + // return 0 : success // return -1 : error (need passthrough) int intercept_policy_enforce(struct intercept_policy_enforcer *enforcer, struct tfe_cmsg *cmsg) diff --git a/common/src/tfe_ctrl_packet.cpp b/common/src/tfe_ctrl_packet.cpp index ae33f3f..c063e31 100644 --- a/common/src/tfe_ctrl_packet.cpp +++ b/common/src/tfe_ctrl_packet.cpp @@ -271,7 +271,6 @@ static int mpack_parse_array(struct ctrl_pkt_parser *handler, mpack_node_t node, static int proxy_parse_messagepack(mpack_node_t node, void *ctx, void *logger) { int ret = 0; - uint64_t tfe_policy_max_id = 0; struct ctrl_pkt_parser *handler = (struct ctrl_pkt_parser *)ctx; if (mpack_node_is_nil(mpack_node_map_cstr(node, "rule_ids"))) @@ -282,12 +281,6 @@ static int proxy_parse_messagepack(mpack_node_t node, void *ctx, void *logger) handler->tfe_policy_id_num = mpack_node_array_length(mpack_node_map_cstr(node, "rule_ids")); for (int i = 0; i < handler->tfe_policy_id_num; i++) { handler->tfe_policy_ids[i] = mpack_node_u64(mpack_node_array_at(mpack_node_map_cstr(node, "rule_ids"), i)); - if (tfe_policy_max_id < handler->tfe_policy_ids[i]) - tfe_policy_max_id = handler->tfe_policy_ids[i]; - } - - if (handler->tfe_policy_id_num) { - tfe_cmsg_set(handler->cmsg, TFE_CMSG_POLICY_ID, (const unsigned char *)&tfe_policy_max_id, sizeof(uint64_t)); } mpack_node_t tcp_handshake = mpack_node_map_cstr(node, "tcp_handshake"); diff --git a/common/src/tfe_packet_io.cpp b/common/src/tfe_packet_io.cpp index aae8f87..c510298 100644 --- a/common/src/tfe_packet_io.cpp +++ b/common/src/tfe_packet_io.cpp @@ -1024,6 +1024,16 @@ static int handle_session_opening(struct metadata *meta, struct ctrl_pkt_parser raw_packet_parser_get_most_inner_tuple4(&raw_parser, &inner_tuple4, logger); tfe_cmsg_get_value(parser->cmsg, TFE_CMSG_TCP_RESTORE_PROTOCOL, (unsigned char *)&stream_protocol_in_char, sizeof(stream_protocol_in_char), &size); + uint64_t rule_id = 0; + ret = intercept_policy_select(thread->ref_proxy->int_ply_enforcer, parser->tfe_policy_ids, parser->tfe_policy_id_num, &rule_id); + if (ret != 0) + { + is_passthrough = 1; + set_passthrough_reason(parser->cmsg, reason_invalid_intercept_param); + goto passthrough; + } + tfe_cmsg_set(parser->cmsg, TFE_CMSG_POLICY_ID, (const unsigned char *)&rule_id, sizeof(uint64_t)); + ret = intercept_policy_enforce(thread->ref_proxy->int_ply_enforcer, parser->cmsg); if (ret != 0) { is_passthrough = 1; @@ -1173,7 +1183,7 @@ passthrough: route_ctx_copy(&s_ctx->raw_meta_e2i->route_ctx, &parser->ack_route_ctx); } - TFE_LOG_INFO(logger, "%s: session %lu %s active first", LOG_TAG_PKTIO, s_ctx->session_id, s_ctx->session_addr); + TFE_LOG_INFO(logger, "%s: session %lu %s active first, hit rule %lu", LOG_TAG_PKTIO, s_ctx->session_id, s_ctx->session_addr, rule_id); session_table_insert(thread->session_table, s_ctx->session_id, &(s_ctx->c2s_info.tuple4), s_ctx, session_value_free_cb); ATOMIC_INC(&(packet_io_fs->session_num));