diff --git a/src/checksum.c b/src/checksum.c index 6f17240..01b73b4 100644 --- a/src/checksum.c +++ b/src/checksum.c @@ -95,11 +95,11 @@ unsigned short checksum_ipv4(struct s_ipv4_addr ip_src, header->ip_dest = ip_dest; header->zeros = 0x0; header->proto = proto; - header->len = htonl((unsigned int) length); + header->len = htons(length); memcpy(buffer + sizeof(struct s_ipv4_pseudo), payload, (int) length); - sum = checksum(buffer, sizeof(struct s_ipv6_pseudo) + (int) length); + sum = checksum(buffer, sizeof(struct s_ipv4_pseudo) + (int) length); free(buffer); diff --git a/src/ipv4.h b/src/ipv4.h index 03d4bb3..faee16a 100644 --- a/src/ipv4.h +++ b/src/ipv4.h @@ -53,7 +53,7 @@ struct s_ipv4_pseudo { struct s_ipv4_addr ip_dest; /* 32 b; destination address */ unsigned char zeros; /* 8 b */ unsigned char proto; /* 8 b; protocol in payload */ - unsigned int len; /* 32 b; payload length */ + unsigned short len; /* 16 b; payload length */ } __attribute__ ((__packed__)); int ipv4(struct s_ethernet *eth, char *packet); diff --git a/src/udp.c b/src/udp.c index cc8f424..4d58889 100644 --- a/src/udp.c +++ b/src/udp.c @@ -58,7 +58,20 @@ int udp_ipv4(struct s_ethernet *eth, struct s_ipv4 *ip4, char *payload, /* parse UDP header */ udp = (struct s_udp *) payload; - /* TODO: checksum recheck */ + /* checksum recheck */ + if (udp->checksum != 0x0000) { + orig_checksum = udp->checksum; + udp->checksum = 0; + udp->checksum = checksum_ipv4(ip4->ip_src, ip4->ip_dest, + payload_size, IPPROTO_UDP, + (unsigned char *) udp); + + if (udp->checksum != orig_checksum) { + /* packet is corrupted and shouldn't be processed */ + printf("[Debug] Wrong checksum\n"); + return 1; + } + } /* find connection in NAT */ connection = nat_in(nat4_udp, ip4->ip_src, udp->port_src, udp->port_dest); @@ -185,16 +198,17 @@ int udp_ipv6(struct s_ethernet *eth, struct s_ipv6 *ip6, char *payload) udp->port_src = connection->ipv4_port_src; /* compute UDP checksum */ - udp->checksum = 0; - /* TODO: checksum computation; in IPv4 it's optional in UDP */ + udp->checksum = 0x0; + udp->checksum = checksum_ipv4(ip4->ip_src, ip4->ip_dest, + htons(ip6->len), IPPROTO_UDP, + (unsigned char *) payload); /* copy the payload data (with new checksum) */ memcpy(packet + sizeof(struct s_ipv4), payload, htons(ip6->len)); /* compute IPv4 checksum */ - ip4->checksum = checksum_ipv4(ip4->ip_src, ip4->ip_dest, - htons(ip4->len), IPPROTO_UDP, - (unsigned char *) udp); + ip4->checksum = 0x0; + ip4->checksum = checksum(ip4, sizeof(struct s_ipv4)); /* send translated packet */ printf("[Debug] transmitting\n");