#include #include #include #include #include "list.h" #include "tcp_reassembly.h" #include "interval_tree.h" struct segment { struct interval_tree_node tree_node; struct list_head list_node; uint64_t time; uint64_t id; char *payload; // Flexible array member }; struct tcp_reassembly { uint8_t enable; uint32_t max_timeout; uint32_t max_segments; uint32_t max_bytes; struct tcp_reassembly_stat stat; struct rb_root_cached tree_root; struct list_head list_root; uint64_t rcv_nxt; // what we want to receive next }; /****************************************************************************** * Private API ******************************************************************************/ /* * The next routines deal with comparing 32 bit unsigned ints * and worry about wraparound (automatic with unsigned arithmetic). */ static inline bool before(uint32_t seq1, uint32_t seq2) { return (int32_t)(seq1 - seq2) < 0; } static int check_options(const struct tcp_reassembly_options *opts) { if (opts == NULL) { TCP_REASSEMBLE_ERROR("invalid options"); return -1; } if (opts->enable) { if (opts->max_timeout < 1 || opts->max_timeout > 60000) { TCP_REASSEMBLE_ERROR("invalid max_timeout: %u, supported range: [1, 60000]", opts->max_timeout); return -1; } } return 0; } /****************************************************************************** * Public API ******************************************************************************/ struct tcp_reassembly *tcp_reassembly_new(struct tcp_reassembly_options *opts) { if (check_options(opts) == -1) { return NULL; } struct tcp_reassembly *assy = (struct tcp_reassembly *)calloc(1, sizeof(struct tcp_reassembly)); if (assy == NULL) { return NULL; } assy->enable = opts->enable; assy->max_timeout = opts->max_timeout; assy->max_segments = opts->max_segments; assy->max_bytes = opts->max_bytes; assy->tree_root = RB_ROOT_CACHED; INIT_LIST_HEAD(&assy->list_root); return assy; } void tcp_reassembly_free(struct tcp_reassembly *assy) { struct segment *seg = NULL; struct interval_tree_node *tree_node = NULL; if (assy) { while ((tree_node = interval_tree_iter_first(&assy->tree_root, 0, UINT64_MAX))) { seg = container_of(tree_node, struct segment, tree_node); interval_tree_remove(&seg->tree_node, &assy->tree_root); list_del(&seg->list_node); free(seg); seg = NULL; } free(assy); assy = NULL; } } void tcp_reassembly_init(struct tcp_reassembly *assy, uint32_t syn_seq) { if (!assy->enable) { return; } assy->rcv_nxt = syn_seq + 1; TCP_REASSEMBLE_DEBUG("reassembler %p init expect seq %lu", assy, assy->rcv_nxt); } void tcp_reassembly_expire(struct tcp_reassembly *assy, uint64_t now) { if (!assy->enable) { return; } uint64_t len; struct segment *seg = NULL; while (!list_empty(&assy->list_root)) { seg = list_first_entry(&assy->list_root, struct segment, list_node); if (seg->time + assy->max_timeout > now) { break; } len = seg->tree_node.last - seg->tree_node.start + 1; assy->stat.timeout_discard_segments++; assy->stat.timeout_discard_bytes += len; assy->stat.curr_segments--; assy->stat.curr_bytes -= len; TCP_REASSEMBLE_DEBUG("reassembler %p expire segment %p [%lu, %lu] (time: %lu, now: %lu)", assy, seg, seg->tree_node.start, seg->tree_node.last, seg->time, now); interval_tree_remove(&seg->tree_node, &assy->tree_root); list_del(&seg->list_node); free(seg); seg = NULL; } } void tcp_reassembly_insert(struct tcp_reassembly *assy, uint32_t offset, const char *payload, uint32_t len, uint64_t now) { if (!assy->enable || len == 0) { return; } uint64_t low = (uint64_t)offset; uint64_t high = (uint64_t)offset + (uint64_t)len - 1; // from uint32_t to uint64_t, so no overflow assy->stat.insert_segments++; assy->stat.insert_bytes += len; if (assy->max_segments > 0 && assy->stat.curr_segments >= assy->max_segments) { assy->stat.overload_bypass_segments++; assy->stat.overload_bypass_bytes += len; TCP_REASSEMBLE_DEBUG("reassembler %p insert [%lu, %lu] failed, reach max packets %u", assy, low, high, assy->max_segments); return; } if (assy->max_bytes > 0 && assy->stat.curr_bytes >= assy->max_bytes) { assy->stat.overload_bypass_segments++; assy->stat.overload_bypass_bytes += len; TCP_REASSEMBLE_DEBUG("reassembler %p insert [%lu, %lu] failed, reach max bytes %u", assy, low, high, assy->max_bytes); return; } if (before(offset + len, assy->rcv_nxt)) { assy->stat.retrans_bypass_segments++; assy->stat.retrans_bypass_bytes += len; TCP_REASSEMBLE_DEBUG("reassembler %p insert [%lu, %lu] failed, less the expect seq %lu", assy, low, high, assy->rcv_nxt); return; } struct segment *seg = (struct segment *)calloc(1, sizeof(struct segment) + len); if (seg == NULL) { assy->stat.overload_bypass_segments++; assy->stat.overload_bypass_bytes += len; TCP_REASSEMBLE_DEBUG("reassembler %p insert [%lu, %lu] failed, calloc segment failed", assy, low, high); return; } seg->tree_node.start = low; seg->tree_node.last = high; seg->time = now; seg->id = assy->stat.insert_segments; seg->payload = (char *)seg + sizeof(struct segment); memcpy(seg->payload, payload, len); list_add_tail(&seg->list_node, &assy->list_root); interval_tree_insert(&seg->tree_node, &assy->tree_root); TCP_REASSEMBLE_DEBUG("reassembler %p insert segment %p [%lu, %lu]", assy, seg, low, high); assy->stat.curr_segments++; assy->stat.curr_bytes += len; } const char *tcp_reassembly_peek(struct tcp_reassembly *assy, uint32_t *len) { *len = 0; if (!assy->enable) { return NULL; } uint64_t id = UINT64_MAX; struct segment *seg = NULL; struct interval_tree_node *tree_node = NULL; struct interval_tree_node *oldest_node = NULL; tree_node = interval_tree_iter_first(&assy->tree_root, assy->rcv_nxt, assy->rcv_nxt); while (tree_node) { seg = container_of(tree_node, struct segment, tree_node); if (seg->id < id) { id = seg->id; oldest_node = tree_node; } tree_node = interval_tree_iter_next(tree_node, assy->rcv_nxt, assy->rcv_nxt); } if (oldest_node == NULL) { return NULL; } uint64_t payload_len = oldest_node->last - oldest_node->start + 1; seg = container_of(oldest_node, struct segment, tree_node); if (oldest_node->start < assy->rcv_nxt) { uint64_t overlap = assy->rcv_nxt - oldest_node->start; *len = (uint16_t)(payload_len - overlap); TCP_REASSEMBLE_DEBUG("reassembler %p peek [%lu, +∞], found segment %p [%lu, %lu] (left overlap: %lu)", assy, assy->rcv_nxt, seg, oldest_node->start, oldest_node->last, overlap); return seg->payload + overlap; } TCP_REASSEMBLE_DEBUG("reassembler %p peek [%lu, +∞], found segment %p [%lu, %lu]", assy, assy->rcv_nxt, seg, oldest_node->start, oldest_node->last); *len = (uint16_t)payload_len; return seg->payload; } void tcp_reassembly_consume(struct tcp_reassembly *assy, uint32_t len) { if (!assy->enable || len == 0) { return; } /* * https://www.ietf.org/rfc/rfc0793.txt * * This space ranges from 0 to 2**32 - 1. * Since the space is finite, all arithmetic dealing with sequence * numbers must be performed modulo 2**32. This unsigned arithmetic * preserves the relationship of sequence numbers as they cycle from * 2**32 - 1 to 0 again. There are some subtleties to computer modulo * arithmetic, so great care should be taken in programming the * comparison of such values. The symbol "=<" means "less than or equal" * (modulo 2**32). * * UINT32_MAX = 4294967295 * 2^32 = 4294967296 * 2^32 - 1 = 4294967295 * seq range: [0, 4294967295] * seq range: [0, UINT32_MAX] */ uint64_t old_exp_seq = assy->rcv_nxt; assy->rcv_nxt += len; if (assy->rcv_nxt > UINT32_MAX) { assy->rcv_nxt = assy->rcv_nxt % 4294967296; } uint64_t new_exp_seq = assy->rcv_nxt; TCP_REASSEMBLE_DEBUG("reassembler %p consume [%lu, %lu], update expect seq %lu -> %lu", assy, old_exp_seq, old_exp_seq + len - 1, old_exp_seq, new_exp_seq); assy->stat.consume_segments++; assy->stat.consume_bytes += len; struct interval_tree_node *node = interval_tree_iter_first(&assy->tree_root, old_exp_seq, old_exp_seq + len - 1); while (node) { if (before(node->last, new_exp_seq)) { struct segment *seg = container_of(node, struct segment, tree_node); uint32_t len = node->last - node->start + 1; assy->stat.remove_segments++; assy->stat.remove_bytes += len; assy->stat.curr_segments--; assy->stat.curr_bytes -= len; TCP_REASSEMBLE_DEBUG("reassembler %p consume [%lu, %lu], delete segment %p [%lu, %lu]", assy, old_exp_seq, old_exp_seq + len - 1, node, node->start, node->last); interval_tree_remove(node, &assy->tree_root); list_del(&seg->list_node); free(seg); node = interval_tree_iter_first(&assy->tree_root, old_exp_seq, old_exp_seq + len - 1); } else { node = interval_tree_iter_next(node, old_exp_seq, old_exp_seq + len - 1); } } } struct tcp_reassembly_stat *tcp_reassembly_get_stat(struct tcp_reassembly *assy) { if (!assy->enable) { return NULL; } return &assy->stat; } void tcp_reassembly_print_stat(struct tcp_reassembly *assy) { if (!assy->enable) { return; } TCP_REASSEMBLE_DEBUG("reassembler %p current : segments %lu, bytes %lu", assy, assy->stat.curr_segments, assy->stat.curr_bytes); TCP_REASSEMBLE_DEBUG("reassembler %p insert : segments %lu, bytes %lu", assy, assy->stat.insert_segments, assy->stat.insert_bytes); TCP_REASSEMBLE_DEBUG("reassembler %p remove : segments %lu, bytes %lu", assy, assy->stat.remove_segments, assy->stat.remove_bytes); TCP_REASSEMBLE_DEBUG("reassembler %p consume : segments %lu, bytes %lu", assy, assy->stat.consume_segments, assy->stat.consume_bytes); TCP_REASSEMBLE_DEBUG("reassembler %p retrans bypass : segments %lu, bytes %lu", assy, assy->stat.retrans_bypass_segments, assy->stat.retrans_bypass_bytes); TCP_REASSEMBLE_DEBUG("reassembler %p overload bypass : segments %lu, bytes %lu", assy, assy->stat.overload_bypass_segments, assy->stat.overload_bypass_bytes); TCP_REASSEMBLE_DEBUG("reassembler %p timeout discard : segments %lu, bytes %lu", assy, assy->stat.timeout_discard_segments, assy->stat.timeout_discard_bytes); }