diff --git a/UdpVpn.cpp b/UdpVpn.cpp index 1b6b2c3..22d8776 100644 --- a/UdpVpn.cpp +++ b/UdpVpn.cpp @@ -76,13 +76,18 @@ size_t UdpVpn::read_from_tun(char* buffer, size_t len) { } size_t UdpVpn::read_from_tun(TunnelledPacket& packet) { - size_t nread = - read_from_tun(packet.get_payload(), packet.get_payload_space()); + size_t payload_space = packet.get_payload_space(); + size_t nread = read_from_tun(packet.get_payload(), payload_space); packet.set_payload_size(nread); if(!packet.parse_as_ipv6()) { debugf("Ignoring packet with invalid header\n"); return 0; } + if(nread != packet.get_ipv6_header().packet_length()) { + debugf("Ignoring packet with bad size (expected %d, got %d, buffer %d)\n", + packet.get_ipv6_header().packet_length(), nread, payload_space); + return 0; + } return nread; } diff --git a/VpnPacket.cpp b/VpnPacket.cpp index f003dd4..2f4c0a4 100644 --- a/VpnPacket.cpp +++ b/VpnPacket.cpp @@ -31,7 +31,7 @@ VpnPacketTLV VpnPacket::first_tlv() { } size_t VpnPacket::get_tunnelled_mtu(size_t udp_mtu) { - return udp_mtu - VPN_HEADER_BYTES - TLV_HEADER_BYTES; + return udp_mtu - OUTER_HEADERS_BYTES - VPN_HEADER_BYTES - TLV_HEADER_BYTES; } uint32_t VpnPacket::get_seqno() const { @@ -102,6 +102,12 @@ void VpnPacketTLV::set_payload_size(uint16_t size) { _packet.increase_payload_size(size - old_size); } +uint16_t VpnPacketTLV::get_payload_space() const { + return _packet.get_payload_space() + - _packet.get_payload_size() + + get_payload_size(); +} + VpnPacket::PayloadType VpnPacketTLV::get_type() const { return (VpnPacket::PayloadType)(*(uint8_t*)(get_data())); } diff --git a/VpnPacket.hpp b/VpnPacket.hpp index dfa3284..cb27d16 100644 --- a/VpnPacket.hpp +++ b/VpnPacket.hpp @@ -159,10 +159,7 @@ class VpnPacketTLV { /// Set the current payload size void set_payload_size(uint16_t size); /// Get the total available raw data space - uint16_t get_payload_space() const { - return _packet.get_payload_space() - - _packet.get_payload_size() - - get_payload_size(); } + uint16_t get_payload_space() const; /// Get a pointer to the raw data (const version) const char* get_data() const { diff --git a/ip_header.cpp b/ip_header.cpp index 6464266..8e0dfd6 100644 --- a/ip_header.cpp +++ b/ip_header.cpp @@ -15,5 +15,6 @@ bool parse_ipv6_header(const char* data, size_t len, IPv6Header& header) { memcpy(&(header.source), data + 8, 16); memcpy(&(header.dest), data + 24, 16); + header.payload_length = ntohs(*(uint16_t*)(data + 4)); return true; } diff --git a/ip_header.hpp b/ip_header.hpp index 1d86666..698e334 100644 --- a/ip_header.hpp +++ b/ip_header.hpp @@ -5,6 +5,9 @@ struct IPv6Header { in6_addr source; in6_addr dest; + uint16_t payload_length; + + uint16_t packet_length() const { return payload_length + 40; } }; /** Parse an IPv6 header, filling `header`. Returns `true` on success. */