#include #include "log.h" #include "list.h" #include "interval_tree.h" #include "tcp_reassembly.h" #define TCP_REASSEMBLY_LOG_DEBUG(format, ...) LOG_DEBUG("tcp_reassembly", format, ##__VA_ARGS__) #define TCP_REASSEMBLY_LOG_ERROR(format, ...) LOG_ERROR("tcp_reassembly", format, ##__VA_ARGS__) struct tcp_segment_private { uint64_t ts; uint64_t id; struct list_head lru; struct interval_tree_node node; struct tcp_segment seg; void *data; // flexible array member }; struct tcp_reassembly { uint64_t max_timeout; uint64_t max_seg_num; uint64_t cur_seg_num; uint64_t sum_seg_num; struct list_head list; struct rb_root_cached root; uint32_t recv_next; }; struct tcp_segment *tcp_segment_new(uint32_t seq, const void *data, uint32_t len) { struct tcp_segment_private *p = (struct tcp_segment_private *)calloc(1, sizeof(struct tcp_segment_private) + len); if (!p) { TCP_REASSEMBLY_LOG_ERROR("calloc failed"); return NULL; } p->node.start = seq; p->node.last = (uint64_t)seq + (uint64_t)len - 1; p->data = (char *)p + sizeof(struct tcp_segment_private); memcpy(p->data, data, len); p->seg.len = len; p->seg.data = p->data; return &p->seg; } void tcp_segment_free(struct tcp_segment *seg) { if (seg) { struct tcp_segment_private *p = container_of(seg, struct tcp_segment_private, seg); free(p); } } struct tcp_reassembly *tcp_reassembly_new(uint64_t max_timeout, uint64_t max_seg_num) { struct tcp_reassembly *assembler = (struct tcp_reassembly *)calloc(1, sizeof(struct tcp_reassembly)); if (!assembler) { TCP_REASSEMBLY_LOG_ERROR("calloc failed"); return NULL; } assembler->max_timeout = max_timeout; assembler->max_seg_num = max_seg_num; assembler->cur_seg_num = 0; assembler->root = RB_ROOT_CACHED; INIT_LIST_HEAD(&assembler->list); return assembler; } void tcp_reassembly_free(struct tcp_reassembly *assembler) { if (assembler) { while (!list_empty(&assembler->list)) { struct tcp_segment_private *p = list_first_entry(&assembler->list, struct tcp_segment_private, lru); assembler->cur_seg_num--; list_del(&p->lru); interval_tree_remove(&p->node, &assembler->root); free(p); } free(assembler); } } /* * return: 1: push tcp segment success (segment overlap) * return: 0: push tcp segment success * return: -1: push tcp segment failed (no space) * return: -2: push tcp segment failed (segment repeat) */ int tcp_reassembly_push(struct tcp_reassembly *assembler, struct tcp_segment *seg, uint64_t now) { if (assembler == NULL) { return -1; } if (assembler->cur_seg_num >= assembler->max_seg_num) { TCP_REASSEMBLY_LOG_ERROR("assembler %p is full", assembler); return -1; } int ret = 0; struct tcp_segment_private *p = container_of(seg, struct tcp_segment_private, seg); struct interval_tree_node *node = interval_tree_iter_first(&assembler->root, p->node.start, p->node.last); if (node) { do { struct tcp_segment_private *t = container_of(node, struct tcp_segment_private, node); if (t->node.start == p->node.start && t->node.last == p->node.last) { TCP_REASSEMBLY_LOG_DEBUG("assembler %p push segment %p [%lu, %lu] failed, segment repeat", assembler, seg, p->node.start, p->node.last); return -2; } } while ((node = interval_tree_iter_next(node, p->node.start, p->node.last))); TCP_REASSEMBLY_LOG_DEBUG("assembler %p push segment %p [%lu, %lu], but segment overlap", assembler, seg, p->node.start, p->node.last); ret = 1; } else { TCP_REASSEMBLY_LOG_DEBUG("assembler %p push segment %p [%lu, %lu]", assembler, seg, p->node.start, p->node.last); } p->ts = now; p->id = assembler->sum_seg_num++; list_add_tail(&p->lru, &assembler->list); interval_tree_insert(&p->node, &assembler->root); assembler->cur_seg_num++; return ret; } struct tcp_segment *tcp_reassembly_pop(struct tcp_reassembly *assembler) { if (assembler == NULL) { return NULL; } struct interval_tree_node *node = interval_tree_iter_first(&assembler->root, assembler->recv_next, assembler->recv_next); if (node == NULL) { return NULL; } uint64_t overlap = 0; uint64_t min_id = UINT64_MAX; struct tcp_segment_private *oldest = NULL; while (node) { struct tcp_segment_private *p = container_of(node, struct tcp_segment_private, node); if (p->id < min_id) { min_id = p->id; oldest = p; } node = interval_tree_iter_next(node, assembler->recv_next, assembler->recv_next); } list_del(&oldest->lru); interval_tree_remove(&oldest->node, &assembler->root); assembler->cur_seg_num--; if (oldest->node.start < assembler->recv_next) { // trim overlap overlap = assembler->recv_next - oldest->node.start; oldest->seg.len -= overlap; oldest->seg.data = (char *)oldest->data + overlap; } TCP_REASSEMBLY_LOG_DEBUG("assembler %p pop segment %p [%lu, %lu], trim overlap %lu", assembler, &oldest->seg, oldest->node.start, oldest->node.last, overlap); // update recv_next assembler->recv_next = uint32_add(assembler->recv_next, oldest->seg.len); return &oldest->seg; } struct tcp_segment *tcp_reassembly_expire(struct tcp_reassembly *assembler, uint64_t now) { if (assembler == NULL) { return NULL; } if (list_empty(&assembler->list)) { return NULL; } struct tcp_segment_private *p = list_first_entry(&assembler->list, struct tcp_segment_private, lru); if (now - p->ts >= assembler->max_timeout) { assembler->cur_seg_num--; list_del(&p->lru); interval_tree_remove(&p->node, &assembler->root); TCP_REASSEMBLY_LOG_DEBUG("assembler %p expire segment %p [%lu, %lu]", assembler, &p->seg, p->node.start, p->node.last); return &p->seg; } else { return NULL; } } void tcp_reassembly_inc_recv_next(struct tcp_reassembly *assembler, uint32_t offset) { if (assembler == NULL) { return; } assembler->recv_next = uint32_add(assembler->recv_next, offset); TCP_REASSEMBLY_LOG_DEBUG("assembler %p inc recv_next %u to %lu", assembler, offset, assembler->recv_next); } void tcp_reassembly_set_recv_next(struct tcp_reassembly *assembler, uint32_t seq) { if (assembler == NULL) { return; } assembler->recv_next = seq; TCP_REASSEMBLY_LOG_DEBUG("assembler %p set recv_next %u", assembler, seq); } uint32_t tcp_reassembly_get_recv_next(struct tcp_reassembly *assembler) { if (assembler == NULL) { return 0; } return assembler->recv_next; } const char *tcp_segment_get_data(const struct tcp_segment *seg) { if (seg == NULL) { return NULL; } else { return (const char *)seg->data; } } uint16_t tcp_segment_get_len(const struct tcp_segment *seg) { if (seg == NULL) { return 0; } else { return seg->len; } }