#include #include "list.h" #include "tcp_reassembly.h" #include "interval_tree.h" struct tcp_segment_private { uint64_t ts; uint64_t id; struct list_head lru; struct interval_tree_node node; struct tcp_segment seg; void *data; // flexible array member }; struct tcp_reassembly { uint64_t max_timeout; uint64_t max_seg_num; uint64_t cur_seg_num; uint64_t sum_seg_num; struct list_head list; struct rb_root_cached root; uint32_t recv_next; }; struct tcp_segment *tcp_segment_new(uint32_t seq, const void *data, uint32_t len) { struct tcp_segment_private *p = (struct tcp_segment_private *)calloc(1, sizeof(struct tcp_segment_private) + len); if (!p) { TCP_REASSEMBLY_LOG_ERROR("calloc failed"); return NULL; } p->node.start = seq; p->node.last = (uint64_t)seq + (uint64_t)len - 1; p->data = (char *)p + sizeof(struct tcp_segment_private); memcpy(p->data, data, len); p->seg.len = len; p->seg.data = p->data; return &p->seg; } void tcp_segment_free(struct tcp_segment *seg) { if (seg) { struct tcp_segment_private *p = container_of(seg, struct tcp_segment_private, seg); free(p); } } struct tcp_reassembly *tcp_reassembly_new(uint64_t max_timeout, uint64_t max_seg_num) { struct tcp_reassembly *assembler = (struct tcp_reassembly *)calloc(1, sizeof(struct tcp_reassembly)); if (!assembler) { TCP_REASSEMBLY_LOG_ERROR("calloc failed"); return NULL; } assembler->max_timeout = max_timeout; assembler->max_seg_num = max_seg_num; assembler->cur_seg_num = 0; assembler->root = RB_ROOT_CACHED; INIT_LIST_HEAD(&assembler->list); return assembler; } void tcp_reassembly_free(struct tcp_reassembly *assembler) { if (assembler) { while (!list_empty(&assembler->list)) { struct tcp_segment_private *p = list_first_entry(&assembler->list, struct tcp_segment_private, lru); assembler->cur_seg_num--; list_del(&p->lru); interval_tree_remove(&p->node, &assembler->root); free(p); } free(assembler); } } // return: 1: success (seg overlap) // return: 0: success // return: -1: failed (no space) int tcp_reassembly_push(struct tcp_reassembly *assembler, struct tcp_segment *seg, uint64_t now) { if (assembler->cur_seg_num >= assembler->max_seg_num) { TCP_REASSEMBLY_LOG_ERROR("assembler is full"); return -1; } int ret = 0; struct tcp_segment_private *p = container_of(seg, struct tcp_segment_private, seg); if (interval_tree_iter_first(&assembler->root, p->node.start, p->node.last)) { TCP_REASSEMBLY_LOG_DEBUG("seg overlap"); ret = 1; } p->ts = now; p->id = assembler->sum_seg_num++; list_add_tail(&p->lru, &assembler->list); interval_tree_insert(&p->node, &assembler->root); assembler->cur_seg_num++; return ret; } struct tcp_segment *tcp_reassembly_pop(struct tcp_reassembly *assembler) { struct interval_tree_node *node; node = interval_tree_iter_first(&assembler->root, assembler->recv_next, assembler->recv_next); if (node == NULL) { return NULL; } uint64_t min_id = UINT64_MAX; struct tcp_segment_private *oldest = NULL; while (node) { struct tcp_segment_private *p = container_of(node, struct tcp_segment_private, node); if (p->id < min_id) { min_id = p->id; oldest = p; } node = interval_tree_iter_next(node, assembler->recv_next, assembler->recv_next); } list_del(&oldest->lru); interval_tree_remove(&oldest->node, &assembler->root); assembler->cur_seg_num--; if (oldest->node.start < assembler->recv_next) { // trim overlap uint64_t overlap = assembler->recv_next - oldest->node.start; oldest->seg.len -= overlap; oldest->seg.data = (char *)oldest->data + overlap; } // update recv_next assembler->recv_next = oldest->node.last + 1; if (assembler->recv_next > UINT32_MAX) { assembler->recv_next = assembler->recv_next % 4294967296; } return &oldest->seg; } struct tcp_segment *tcp_reassembly_expire(struct tcp_reassembly *assembler, uint64_t now) { if (list_empty(&assembler->list)) { return NULL; } struct tcp_segment_private *p = list_first_entry(&assembler->list, struct tcp_segment_private, lru); if (now - p->ts >= assembler->max_timeout) { assembler->cur_seg_num--; list_del(&p->lru); interval_tree_remove(&p->node, &assembler->root); return &p->seg; } else { return NULL; } } void tcp_reassembly_inc_recv_next(struct tcp_reassembly *assembler, uint32_t offset) { assembler->recv_next += offset; if (assembler->recv_next > UINT32_MAX) { assembler->recv_next = assembler->recv_next % 4294967296; } } void tcp_reassembly_set_recv_next(struct tcp_reassembly *assembler, uint32_t seq) { assembler->recv_next = seq; } uint32_t tcp_reassembly_get_recv_next(struct tcp_reassembly *assembler) { return assembler->recv_next; }