/*
 * Internet checksum (RFC 1071) for raw buffers and for transport segments
 * (TCP/UDP-style, with pseudo-header) carried over IPv4 and IPv6.
 */
#include "checksum.h"

#include <arpa/inet.h>
#include <netinet/in.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/*
 * Add `len` bytes at `buf` to the running one's-complement sum `sum`,
 * treating them as consecutive 16-bit words read in host byte order.
 *
 * Words are loaded with memcpy, so `buf` needs no particular alignment and no
 * strict-aliasing rule is violated. An odd trailing byte is padded with a zero
 * byte *in memory* (RFC 1071), which is correct on both little- and big-endian
 * hosts. Because of that padding, a contiguous region must not be split
 * across calls at an odd offset.
 *
 * Each word adds at most 0xFFFF, so the 64-bit accumulator cannot overflow
 * for any realistic length; carries are folded once, in cksum_finish().
 */
static uint64_t cksum_add(uint64_t sum, const void *buf, size_t len)
{
    const unsigned char *p = buf;
    uint16_t word;

    for (; len > 1; p += 2, len -= 2) {
        memcpy(&word, p, sizeof word);
        sum += word;
    }
    if (len == 1) {
        const unsigned char last[2] = { p[0], 0 };

        memcpy(&word, last, sizeof word);
        sum += word;
    }
    return sum;
}

/* Fold all carries of `sum` into 16 bits and return its one's complement. */
static uint16_t cksum_finish(uint64_t sum)
{
    while (sum >> 16) {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    return (uint16_t)~sum;
}

/*
 * Compute the Internet checksum of `len` bytes at `data` (any alignment).
 *
 * The result has the same byte order as the data words, so it can be stored
 * as-is into a network-order header field. A buffer that already contains its
 * correct checksum yields 0. A non-positive `len` is treated as an empty
 * buffer (result 0xFFFF).
 */
uint16_t checksum(const void *data, int len)
{
    size_t n = len > 0 ? (size_t)len : 0;

    return cksum_finish(cksum_add(0, data, n));
}

/*
 * Compute the checksum of an IPv4 transport segment, including the
 * pseudo-header (source, destination, zero, protocol, length).
 *
 * @param l4_hdr_ptr    start of the transport header (any alignment)
 * @param l4_total_len  transport header + payload length in bytes, host order
 * @param l4_proto      IP protocol number (e.g. 6 = TCP, 17 = UDP)
 * @param src_addr      source address (s_addr in network order)
 * @param dst_addr      destination address (s_addr in network order)
 * @return checksum, byte order as for checksum(); 0 when verifying a segment
 *         whose checksum field is already correct.
 *
 * No 0 -> 0xFFFF substitution is performed; UDP (RFC 768) transmits a
 * computed 0 as 0xFFFF, which is left to the caller.
 */
uint16_t checksum_v4(const void *l4_hdr_ptr, uint16_t l4_total_len,
                     uint8_t l4_proto, struct in_addr *src_addr,
                     struct in_addr *dst_addr)
{
    /* Pseudo-header after the addresses: zero byte, protocol, 16-bit length. */
    const unsigned char tail[4] = {
        0, l4_proto,
        (unsigned char)(l4_total_len >> 8),
        (unsigned char)(l4_total_len & 0xFF),
    };
    uint64_t sum = 0;

    sum = cksum_add(sum, &src_addr->s_addr, sizeof src_addr->s_addr);
    sum = cksum_add(sum, &dst_addr->s_addr, sizeof dst_addr->s_addr);
    sum = cksum_add(sum, tail, sizeof tail);
    sum = cksum_add(sum, l4_hdr_ptr, l4_total_len);
    return cksum_finish(sum);
}

/*
 * Compute the checksum of an IPv6 upper-layer segment, including the
 * RFC 8200 section 8.1 pseudo-header (source, destination, 32-bit length,
 * three zero bytes, next header).
 *
 * @param l4_hdr_ptr    start of the upper-layer header (any alignment)
 * @param l4_total_len  upper-layer header + payload length in bytes, host order
 * @param l4_proto      next-header value (e.g. 6 = TCP, 17 = UDP, 58 = ICMPv6)
 * @param src_addr      source address
 * @param dst_addr      destination address
 * @return checksum, byte order as for checksum(); 0 when verifying a segment
 *         whose checksum field is already correct.
 *
 * No 0 -> 0xFFFF substitution is performed; that is left to the caller.
 */
uint16_t checksum_v6(const void *l4_hdr_ptr, uint16_t l4_total_len,
                     uint8_t l4_proto, struct in6_addr *src_addr,
                     struct in6_addr *dst_addr)
{
    /* Pseudo-header after the addresses: 32-bit length, 3 zeros, next header. */
    const unsigned char tail[8] = {
        0, 0,
        (unsigned char)(l4_total_len >> 8),
        (unsigned char)(l4_total_len & 0xFF),
        0, 0, 0, l4_proto,
    };
    uint64_t sum = 0;

    sum = cksum_add(sum, src_addr->s6_addr, sizeof src_addr->s6_addr);
    sum = cksum_add(sum, dst_addr->s6_addr, sizeof dst_addr->s6_addr);
    sum = cksum_add(sum, tail, sizeof tail);
    sum = cksum_add(sum, l4_hdr_ptr, l4_total_len);
    return cksum_finish(sum);
}