Refactor TCP reassembly: the session now knows where each TCP segment comes from (a raw packet or the TCP segment queue)

This commit is contained in:
luwenpeng
2024-04-02 16:21:39 +08:00
parent a509f0ce3b
commit e8e60cee6d
25 changed files with 678 additions and 1419 deletions

View File

@@ -1,355 +1,206 @@
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <assert.h>
#include "list.h"
#include "list.h"
#include "tcp_reassembly.h"
#include "interval_tree.h"
struct segment
struct tcp_segment_private
{
struct interval_tree_node tree_node;
struct list_head list_node;
uint64_t time;
uint64_t ts;
uint64_t id;
char *payload; // Flexible array member
struct list_head lru;
struct interval_tree_node node;
struct tcp_segment seg;
void *data; // flexible array member
};
struct tcp_reassembly
{
uint8_t enable;
uint32_t max_timeout;
uint32_t max_segments;
uint32_t max_bytes;
struct tcp_reassembly_stat stat;
uint64_t max_timeout;
uint64_t max_seg_num;
uint64_t cur_seg_num;
uint64_t sum_seg_num;
struct rb_root_cached tree_root;
struct list_head list_root;
uint64_t rcv_nxt; // what we want to receive next
struct list_head list;
struct rb_root_cached root;
uint32_t recv_next;
};
/******************************************************************************
* Private API
******************************************************************************/
/*
* The next routines deal with comparing 32 bit unsigned ints
* and worry about wraparound (automatic with unsigned arithmetic).
*/
static inline bool before(uint32_t seq1, uint32_t seq2)
struct tcp_segment *tcp_segment_new(uint32_t seq, const void *data, uint32_t len)
{
return (int32_t)(seq1 - seq2) < 0;
struct tcp_segment_private *p = (struct tcp_segment_private *)calloc(1, sizeof(struct tcp_segment_private) + len);
if (!p)
{
TCP_REASSEMBLY_LOG_ERROR("calloc failed");
return NULL;
}
p->node.start = seq;
p->node.last = (uint64_t)seq + (uint64_t)len - 1;
p->data = (char *)p + sizeof(struct tcp_segment_private);
memcpy(p->data, data, len);
p->seg.len = len;
p->seg.data = p->data;
return &p->seg;
}
static int check_options(const struct tcp_reassembly_options *opts)
void tcp_segment_free(struct tcp_segment *seg)
{
if (opts == NULL)
if (seg)
{
TCP_REASSEMBLE_ERROR("invalid options");
struct tcp_segment_private *p = container_of(seg, struct tcp_segment_private, seg);
free(p);
}
}
/*
 * Allocate and initialise an empty TCP reassembler.
 *
 * max_timeout: stored as-is; compared against a segment's age on expiry
 * max_seg_num: stored as-is; upper bound on the number of queued segments
 *
 * Returns the new reassembler, or NULL when the allocation fails.
 * Release the result with tcp_reassembly_free().
 */
struct tcp_reassembly *tcp_reassembly_new(uint64_t max_timeout, uint64_t max_seg_num)
{
    struct tcp_reassembly *assembler = (struct tcp_reassembly *)calloc(1, sizeof(*assembler));
    if (assembler == NULL)
    {
        TCP_REASSEMBLY_LOG_ERROR("calloc failed");
        return NULL;
    }

    // start with empty containers; every other field stays zeroed by calloc
    INIT_LIST_HEAD(&assembler->list);
    assembler->root = RB_ROOT_CACHED;
    assembler->cur_seg_num = 0;

    // caller-supplied limits
    assembler->max_seg_num = max_seg_num;
    assembler->max_timeout = max_timeout;

    return assembler;
}
/*
 * Destroy a reassembler together with every segment still queued in it.
 * Passing NULL is a no-op.
 */
void tcp_reassembly_free(struct tcp_reassembly *assembler)
{
    if (assembler == NULL)
    {
        return;
    }

    // drain from the head of the list until nothing is left
    while (!list_empty(&assembler->list))
    {
        struct tcp_segment_private *head = list_first_entry(&assembler->list, struct tcp_segment_private, lru);

        assembler->cur_seg_num--;
        list_del(&head->lru);
        interval_tree_remove(&head->node, &assembler->root);
        free(head);
    }

    free(assembler);
}
// return: 1: success (seg overlap)
// return: 0: success
// return: -1: failed (no space)
int tcp_reassembly_push(struct tcp_reassembly *assembler, struct tcp_segment *seg, uint64_t now)
{
if (assembler->cur_seg_num >= assembler->max_seg_num)
{
TCP_REASSEMBLY_LOG_ERROR("assembler is full");
return -1;
}
if (opts->enable)
int ret = 0;
struct tcp_segment_private *p = container_of(seg, struct tcp_segment_private, seg);
if (interval_tree_iter_first(&assembler->root, p->node.start, p->node.last))
{
if (opts->max_timeout < 1 || opts->max_timeout > 60000)
{
TCP_REASSEMBLE_ERROR("invalid max_timeout: %u, supported range: [1, 60000]", opts->max_timeout);
return -1;
}
TCP_REASSEMBLY_LOG_DEBUG("seg overlap");
ret = 1;
}
return 0;
p->ts = now;
p->id = assembler->sum_seg_num++;
list_add_tail(&p->lru, &assembler->list);
interval_tree_insert(&p->node, &assembler->root);
assembler->cur_seg_num++;
return ret;
}
/******************************************************************************
* Public API
******************************************************************************/
struct tcp_reassembly *tcp_reassembly_new(struct tcp_reassembly_options *opts)
struct tcp_segment *tcp_reassembly_pop(struct tcp_reassembly *assembler)
{
if (check_options(opts) == -1)
struct interval_tree_node *node;
node = interval_tree_iter_first(&assembler->root, assembler->recv_next, assembler->recv_next);
if (node == NULL)
{
return NULL;
}
struct tcp_reassembly *assy = (struct tcp_reassembly *)calloc(1, sizeof(struct tcp_reassembly));
if (assy == NULL)
{
return NULL;
}
assy->enable = opts->enable;
assy->max_timeout = opts->max_timeout;
assy->max_segments = opts->max_segments;
assy->max_bytes = opts->max_bytes;
assy->tree_root = RB_ROOT_CACHED;
INIT_LIST_HEAD(&assy->list_root);
return assy;
}
/*
 * Release a reassembler and all segments it still holds.
 * Passing NULL is a no-op.
 */
void tcp_reassembly_free(struct tcp_reassembly *assy)
{
    if (assy == NULL)
    {
        return;
    }

    // keep asking the tree for any node in [0, UINT64_MAX] until it is empty
    struct interval_tree_node *it = interval_tree_iter_first(&assy->tree_root, 0, UINT64_MAX);
    while (it != NULL)
    {
        struct segment *victim = container_of(it, struct segment, tree_node);

        interval_tree_remove(&victim->tree_node, &assy->tree_root);
        list_del(&victim->list_node);
        free(victim);

        it = interval_tree_iter_first(&assy->tree_root, 0, UINT64_MAX);
    }

    free(assy);
}
/*
 * Set the sequence number the reassembler expects next to syn_seq + 1
 * (the addition is done on the uint32_t argument).
 * Does nothing when the reassembler is disabled.
 */
void tcp_reassembly_init(struct tcp_reassembly *assy, uint32_t syn_seq)
{
    if (assy->enable)
    {
        assy->rcv_nxt = syn_seq + 1;
        TCP_REASSEMBLE_DEBUG("reassembler %p init expect seq %lu", assy, assy->rcv_nxt);
    }
}
/*
 * Discard queued segments whose insertion time plus max_timeout is not
 * later than `now`.
 *
 * The list is walked from its head and the walk stops at the first segment
 * that has not yet timed out. Discarded segments are accounted in the
 * timeout_discard_* counters and subtracted from the curr_* counters.
 * Does nothing when the reassembler is disabled.
 */
void tcp_reassembly_expire(struct tcp_reassembly *assy, uint64_t now)
{
    if (!assy->enable)
    {
        return;
    }

    while (!list_empty(&assy->list_root))
    {
        struct segment *stale = list_first_entry(&assy->list_root, struct segment, list_node);
        if (stale->time + assy->max_timeout > now)
        {
            // head is still fresh, stop here
            break;
        }

        // [start, last] is inclusive, hence the +1
        uint64_t bytes = stale->tree_node.last - stale->tree_node.start + 1;

        assy->stat.timeout_discard_segments += 1;
        assy->stat.timeout_discard_bytes += bytes;
        assy->stat.curr_segments -= 1;
        assy->stat.curr_bytes -= bytes;

        TCP_REASSEMBLE_DEBUG("reassembler %p expire segment %p [%lu, %lu] (time: %lu, now: %lu)", assy, stale, stale->tree_node.start, stale->tree_node.last, stale->time, now);

        interval_tree_remove(&stale->tree_node, &assy->tree_root);
        list_del(&stale->list_node);
        free(stale);
    }
}
/*
 * Queue a private copy of one TCP payload, indexed by its sequence range.
 *
 * assy:    reassembler; the call is ignored when it is disabled
 * offset:  sequence number of the first payload byte
 * payload: bytes to copy (len bytes are read)
 * len:     payload length; a zero length is ignored
 * now:     timestamp recorded on the segment
 *
 * Every non-ignored call is counted in insert_segments / insert_bytes.
 * The payload is then bypassed (counted, not stored) when:
 *   - max_segments is non-zero and curr_segments has reached it,
 *   - max_bytes is non-zero and curr_bytes has reached it,
 *   - before(offset + len, rcv_nxt) holds (counted as a retransmission),
 *   - the allocation fails (counted as overload).
 */
void tcp_reassembly_insert(struct tcp_reassembly *assy, uint32_t offset, const char *payload, uint32_t len, uint64_t now)
{
    if (!assy->enable || len == 0)
    {
        return;
    }

    // widen to 64 bits first so that offset + len - 1 cannot overflow
    uint64_t first = offset;
    uint64_t final = first + len - 1;

    assy->stat.insert_segments += 1;
    assy->stat.insert_bytes += len;

    if (assy->max_segments > 0 && assy->stat.curr_segments >= assy->max_segments)
    {
        assy->stat.overload_bypass_segments += 1;
        assy->stat.overload_bypass_bytes += len;
        TCP_REASSEMBLE_DEBUG("reassembler %p insert [%lu, %lu] failed, reach max packets %u", assy, first, final, assy->max_segments);
        return;
    }

    if (assy->max_bytes > 0 && assy->stat.curr_bytes >= assy->max_bytes)
    {
        assy->stat.overload_bypass_segments += 1;
        assy->stat.overload_bypass_bytes += len;
        TCP_REASSEMBLE_DEBUG("reassembler %p insert [%lu, %lu] failed, reach max bytes %u", assy, first, final, assy->max_bytes);
        return;
    }

    if (before(offset + len, assy->rcv_nxt))
    {
        assy->stat.retrans_bypass_segments += 1;
        assy->stat.retrans_bypass_bytes += len;
        TCP_REASSEMBLE_DEBUG("reassembler %p insert [%lu, %lu] failed, less the expect seq %lu", assy, first, final, assy->rcv_nxt);
        return;
    }

    // header and payload copy live in one allocation
    struct segment *entry = (struct segment *)calloc(1, sizeof(*entry) + len);
    if (entry == NULL)
    {
        assy->stat.overload_bypass_segments += 1;
        assy->stat.overload_bypass_bytes += len;
        TCP_REASSEMBLE_DEBUG("reassembler %p insert [%lu, %lu] failed, calloc segment failed", assy, first, final);
        return;
    }

    entry->tree_node.start = first;
    entry->tree_node.last = final;
    entry->time = now;
    entry->id = assy->stat.insert_segments;
    entry->payload = (char *)(entry + 1); // payload bytes start right after the header
    memcpy(entry->payload, payload, len);

    list_add_tail(&entry->list_node, &assy->list_root);
    interval_tree_insert(&entry->tree_node, &assy->tree_root);
    TCP_REASSEMBLE_DEBUG("reassembler %p insert segment %p [%lu, %lu]", assy, entry, first, final);

    assy->stat.curr_segments += 1;
    assy->stat.curr_bytes += len;
}
/*
 * Look up the payload that continues the stream at the expected sequence
 * number, without removing it.
 *
 * assy: reassembler; when it is disabled the call returns NULL
 * len:  out parameter, always written; 0 when nothing is returned,
 *       otherwise the number of readable bytes at the returned pointer
 *
 * Among all stored segments whose range contains rcv_nxt, the one with the
 * smallest id is chosen. If that segment starts before rcv_nxt, the bytes
 * in front of rcv_nxt are skipped.
 *
 * Returns a pointer into the stored segment (still owned by the
 * reassembler), or NULL when no stored segment contains rcv_nxt.
 *
 * Fix: the length used to be cast through uint16_t although *len is a
 * uint32_t and segments are inserted with a uint32_t length, so any
 * segment of 65536 bytes or more reported a truncated length.
 */
const char *tcp_reassembly_peek(struct tcp_reassembly *assy, uint32_t *len)
{
    *len = 0;

    if (!assy->enable)
    {
        return NULL;
    }

    uint64_t id = UINT64_MAX;
    struct segment *seg = NULL;
    struct interval_tree_node *tree_node = NULL;
    struct interval_tree_node *oldest_node = NULL;

    // several stored segments may cover rcv_nxt; keep the one with the smallest id
    tree_node = interval_tree_iter_first(&assy->tree_root, assy->rcv_nxt, assy->rcv_nxt);
    while (tree_node)
    {
        seg = container_of(tree_node, struct segment, tree_node);
        if (seg->id < id)
        {
            id = seg->id;
            oldest_node = tree_node;
        }
        tree_node = interval_tree_iter_next(tree_node, assy->rcv_nxt, assy->rcv_nxt);
    }

    if (oldest_node == NULL)
    {
        return NULL;
    }

    // [start, last] is inclusive; the span comes from a uint32_t length, so it fits in 32 bits
    uint64_t payload_len = oldest_node->last - oldest_node->start + 1;
    seg = container_of(oldest_node, struct segment, tree_node);

    if (oldest_node->start < assy->rcv_nxt)
    {
        // skip the bytes in front of rcv_nxt
        uint64_t overlap = assy->rcv_nxt - oldest_node->start;
        *len = (uint32_t)(payload_len - overlap);
        TCP_REASSEMBLE_DEBUG("reassembler %p peek [%lu, +∞], found segment %p [%lu, %lu] (left overlap: %lu)", assy, assy->rcv_nxt, seg, oldest_node->start, oldest_node->last, overlap);
        return seg->payload + overlap;
    }

    TCP_REASSEMBLE_DEBUG("reassembler %p peek [%lu, +∞], found segment %p [%lu, %lu]", assy, assy->rcv_nxt, seg, oldest_node->start, oldest_node->last);
    *len = (uint32_t)payload_len;
    return seg->payload;
}
void tcp_reassembly_consume(struct tcp_reassembly *assy, uint32_t len)
{
if (!assy->enable || len == 0)
{
return;
}
/*
* https://www.ietf.org/rfc/rfc0793.txt
*
* This space ranges from 0 to 2**32 - 1.
* Since the space is finite, all arithmetic dealing with sequence
* numbers must be performed modulo 2**32. This unsigned arithmetic
* preserves the relationship of sequence numbers as they cycle from
* 2**32 - 1 to 0 again. There are some subtleties to computer modulo
* arithmetic, so great care should be taken in programming the
* comparison of such values. The symbol "=<" means "less than or equal"
* (modulo 2**32).
*
* UINT32_MAX = 4294967295
* 2^32 = 4294967296
* 2^32 - 1 = 4294967295
* seq range: [0, 4294967295]
* seq range: [0, UINT32_MAX]
*/
uint64_t old_exp_seq = assy->rcv_nxt;
assy->rcv_nxt += len;
if (assy->rcv_nxt > UINT32_MAX)
{
assy->rcv_nxt = assy->rcv_nxt % 4294967296;
}
uint64_t new_exp_seq = assy->rcv_nxt;
TCP_REASSEMBLE_DEBUG("reassembler %p consume [%lu, %lu], update expect seq %lu -> %lu", assy, old_exp_seq, old_exp_seq + len - 1, old_exp_seq, new_exp_seq);
assy->stat.consume_segments++;
assy->stat.consume_bytes += len;
struct interval_tree_node *node = interval_tree_iter_first(&assy->tree_root, old_exp_seq, old_exp_seq + len - 1);
uint64_t min_id = UINT64_MAX;
struct tcp_segment_private *oldest;
while (node)
{
if (before(node->last, new_exp_seq))
struct tcp_segment_private *p = container_of(node, struct tcp_segment_private, node);
if (p->id < min_id)
{
struct segment *seg = container_of(node, struct segment, tree_node);
uint32_t len = node->last - node->start + 1;
assy->stat.remove_segments++;
assy->stat.remove_bytes += len;
assy->stat.curr_segments--;
assy->stat.curr_bytes -= len;
TCP_REASSEMBLE_DEBUG("reassembler %p consume [%lu, %lu], delete segment %p [%lu, %lu]", assy, old_exp_seq, old_exp_seq + len - 1, node, node->start, node->last);
interval_tree_remove(node, &assy->tree_root);
list_del(&seg->list_node);
free(seg);
node = interval_tree_iter_first(&assy->tree_root, old_exp_seq, old_exp_seq + len - 1);
}
else
{
node = interval_tree_iter_next(node, old_exp_seq, old_exp_seq + len - 1);
min_id = p->id;
oldest = p;
}
node = interval_tree_iter_next(node, assembler->recv_next, assembler->recv_next);
}
list_del(&oldest->lru);
interval_tree_remove(&oldest->node, &assembler->root);
assembler->cur_seg_num--;
if (oldest->node.start < assembler->recv_next)
{
// trim overlap
uint64_t overlap = assembler->recv_next - oldest->node.start;
oldest->seg.len -= overlap;
oldest->seg.data = (char *)oldest->data + overlap;
}
// update recv_next
assembler->recv_next = oldest->node.last + 1;
if (assembler->recv_next > UINT32_MAX)
{
assembler->recv_next = assembler->recv_next % 4294967296;
}
return &oldest->seg;
}
struct tcp_reassembly_stat *tcp_reassembly_get_stat(struct tcp_reassembly *assy)
struct tcp_segment *tcp_reassembly_expire(struct tcp_reassembly *assembler, uint64_t now)
{
if (!assy->enable)
if (list_empty(&assembler->list))
{
return NULL;
}
return &assy->stat;
}
void tcp_reassembly_print_stat(struct tcp_reassembly *assy)
{
if (!assy->enable)
struct tcp_segment_private *p = list_first_entry(&assembler->list, struct tcp_segment_private, lru);
if (now - p->ts >= assembler->max_timeout)
{
return;
assembler->cur_seg_num--;
list_del(&p->lru);
interval_tree_remove(&p->node, &assembler->root);
return &p->seg;
}
else
{
return NULL;
}
TCP_REASSEMBLE_DEBUG("reassembler %p current : segments %lu, bytes %lu", assy, assy->stat.curr_segments, assy->stat.curr_bytes);
TCP_REASSEMBLE_DEBUG("reassembler %p insert : segments %lu, bytes %lu", assy, assy->stat.insert_segments, assy->stat.insert_bytes);
TCP_REASSEMBLE_DEBUG("reassembler %p remove : segments %lu, bytes %lu", assy, assy->stat.remove_segments, assy->stat.remove_bytes);
TCP_REASSEMBLE_DEBUG("reassembler %p consume : segments %lu, bytes %lu", assy, assy->stat.consume_segments, assy->stat.consume_bytes);
TCP_REASSEMBLE_DEBUG("reassembler %p retrans bypass : segments %lu, bytes %lu", assy, assy->stat.retrans_bypass_segments, assy->stat.retrans_bypass_bytes);
TCP_REASSEMBLE_DEBUG("reassembler %p overload bypass : segments %lu, bytes %lu", assy, assy->stat.overload_bypass_segments, assy->stat.overload_bypass_bytes);
TCP_REASSEMBLE_DEBUG("reassembler %p timeout discard : segments %lu, bytes %lu", assy, assy->stat.timeout_discard_segments, assy->stat.timeout_discard_bytes);
}
/*
 * Advance the expected sequence number by `offset`.
 * The result is reduced modulo 2^32, i.e. it wraps around past UINT32_MAX.
 */
void tcp_reassembly_inc_recv_next(struct tcp_reassembly *assembler, uint32_t offset)
{
    // truncating the sum to 32 bits is the modulo-2^32 wrap
    assembler->recv_next = (uint32_t)(assembler->recv_next + offset);
}
/*
 * Overwrite the expected sequence number with `seq`.
 */
void tcp_reassembly_set_recv_next(struct tcp_reassembly *assembler, uint32_t seq)
{
    assembler->recv_next = seq;
}
/*
 * Return the sequence number the reassembler currently expects next.
 */
uint32_t tcp_reassembly_get_recv_next(struct tcp_reassembly *assembler)
{
    uint32_t expected = assembler->recv_next;

    return expected;
}