diff --git a/UdpVpn.cpp b/UdpVpn.cpp index 69a9ec8..6a0d4bd 100644 --- a/UdpVpn.cpp +++ b/UdpVpn.cpp @@ -14,7 +14,7 @@ static const size_t VPN_MTU = 1460; // TODO determine this -- issue #3 UdpVpn::UdpVpn() - : _stopped(false), _vpn_mtu(VPN_MTU), _tun_dev("cvpn%d") + : _stopped(false), _vpn_mtu(VPN_MTU), _tun_dev("cvpn%d"), _peer(nullptr) { _socket = socket(AF_INET6, SOCK_DGRAM, 0); if(_socket < 0) @@ -74,7 +74,7 @@ size_t UdpVpn::read_from_tun(char* buffer, size_t len) { return _tun_dev.read(buffer, len); } -size_t UdpVpn::read_from_tun(VpnPacket& packet) { +size_t UdpVpn::read_from_tun(TunnelledPacket& packet) { size_t nread = read_from_tun(packet.get_payload(), packet.get_payload_space()); packet.set_payload_size(nread); @@ -118,10 +118,75 @@ size_t UdpVpn::read_from_udp(VpnPacket& packet, sockaddr_in6& peer_addr) { size_t nread = read_from_udp(packet.get_data(), packet.get_data_space(), peer_addr); packet.set_data_size(nread); - if(!packet.parse_as_ipv6()) { - debugf("Ignoring packet with invalid header\n"); - return 0; - } return nread; } +size_t UdpVpn::transmit_to_peer(VpnPacket& packet) { + if(!_peer) { + debugf("Dropping packet: no peer yet.\n"); + return 0; + } + return _peer->write(packet); +} + +void UdpVpn::receive_from_tun() { + VpnPacket packet(_vpn_mtu); + TunnelledPacket tunnelled = TunnelledPacket::create(packet); + size_t nread = read_from_tun(tunnelled); + if(nread == 0) + return; + + if(!_peer) { + debugf("Dropping packet: no peer yet.\n"); + return; + } + packet.set_peer(_peer.get()); + + kdebugf("Transmitting %s -> %s, size %d\n", + format_address(tunnelled.get_ipv6_header().source.s6_addr), + format_address(tunnelled.get_ipv6_header().dest.s6_addr), + nread); + + packet.prepare_for_sending(); + transmit_to_peer(packet); +} + +void UdpVpn::receive_from_udp() { + VpnPacket packet(_vpn_mtu); + sockaddr_in6 peer_ext_addr; + size_t nread = read_from_udp(packet, peer_ext_addr); + if(nread == 0) + return; + + // If we don't have a peer yet -- we're just setting the peer to nullptr. + packet.set_peer(_peer.get()); + + for(VpnPacketTLV tlv=packet.first_tlv(); + !tlv.past_the_end(); + tlv.seek_next_tlv()) + { + switch(tlv.get_type()) { + case VpnPacket::PAYLOAD_TYPE_TUNNELLED: + { + TunnelledPacket tunnelled(tlv); + acquire_peer(tunnelled, peer_ext_addr); + receive_tunnelled_tlv(tunnelled); + } + break; + + case VpnPacket::PAYLOAD_TYPE_UNDEF: + default: + debugf("#%d+%lu: ignoring TLV with bad type %d.\n", + packet.get_seqno(), tlv.get_offset(), + tlv.get_type()); + break; + } + } +} + +void UdpVpn::receive_tunnelled_tlv(TunnelledPacket& packet) { + // Reinject into tun + kdebugf("Reinjecting tunnelled packet of size %d\n", + packet.get_payload_size()); + _tun_dev.write(packet.get_payload(), packet.get_payload_size()); +} diff --git a/UdpVpn.hpp b/UdpVpn.hpp index c0bdc47..fa98b8c 100644 --- a/UdpVpn.hpp +++ b/UdpVpn.hpp @@ -42,18 +42,27 @@ class UdpVpn { void stop() { _stopped = true; } protected: - virtual void receive_from_tun() = 0; - virtual void receive_from_udp() = 0; + virtual void acquire_peer( + TunnelledPacket& packet, + const sockaddr_in6& peer_ext_addr) = 0; size_t read_from_tun(char* buffer, size_t len); - size_t read_from_tun(VpnPacket& packet); + size_t read_from_tun(TunnelledPacket& packet); size_t read_from_udp(char* buffer, size_t len, sockaddr_in6& peer_addr); size_t read_from_udp(VpnPacket& packet, sockaddr_in6& peer_addr); + size_t transmit_to_peer(VpnPacket& packet); + + void receive_from_tun(); + void receive_from_udp(); + + void receive_tunnelled_tlv(TunnelledPacket& packet); + int _socket; bool _stopped; size_t _vpn_mtu; TunDevice _tun_dev; + std::unique_ptr _peer; }; diff --git a/UdpVpnClient.cpp b/UdpVpnClient.cpp index 63a5bed..0363c0b 100644 --- a/UdpVpnClient.cpp +++ b/UdpVpnClient.cpp @@ -4,55 +4,14 @@ #include "ip_header.hpp" UdpVpnClient::UdpVpnClient(const struct sockaddr_in6& server) : UdpVpn() { - memset(&_server_addr, 0, sizeof(_server_addr)); - set_server(server); + _peer = std::make_unique(this, server, in6addr_any); } -void UdpVpnClient::set_server(const struct sockaddr_in6& server_addr) { - if(server_addr.sin6_family != AF_INET6) - throw UdpVpn::InitializationError("Server address must be IPv6"); - - memcpy(&_server_addr, &server_addr, sizeof(_server_addr)); -} - -void UdpVpnClient::receive_from_tun() { - VpnPacket packet(_vpn_mtu); - size_t nread = read_from_tun(packet); - if(nread == 0) +void UdpVpnClient::acquire_peer( + TunnelledPacket& packet, + const sockaddr_in6&) +{ + if(!packet.parse_as_ipv6()) return; - - kdebugf("Transmitting %s -> %s, size %d\n", - format_address(packet.get_ipv6_header().source.s6_addr), - format_address(packet.get_ipv6_header().dest.s6_addr), - nread); - - packet.prepare_for_sending(); - write_to_server(packet); -} - -void UdpVpnClient::receive_from_udp() { - VpnPacket packet(_vpn_mtu); - sockaddr_in6 peer_addr; - size_t nread = read_from_udp(packet, peer_addr); - if(nread == 0) - return; - - // Reinject into tun - kdebugf("Receiving packet #%u of size %d from %s\n", - packet.get_seqno(), - nread, - format_address(packet.get_ipv6_header().source.s6_addr)); - _tun_dev.write(packet.get_payload(), packet.get_payload_size()); -} - -size_t UdpVpnClient::write_to_server(const VpnPacket& packet) { - ssize_t nsent; - - nsent = sendto(_socket, packet.get_data(), packet.get_data_size(), - MSG_CONFIRM, - (const struct sockaddr*) &_server_addr, sizeof(_server_addr)); - if(nsent < 0) - throw NetError("Could not send UDP packet", errno, true); - - return (size_t) nsent; + _peer->set_int_addr(packet.get_ipv6_header().source); } diff --git a/UdpVpnClient.hpp b/UdpVpnClient.hpp index 86d455c..f3d371b 100644 --- a/UdpVpnClient.hpp +++ b/UdpVpnClient.hpp @@ -7,12 +7,7 @@ class UdpVpnClient: public UdpVpn { UdpVpnClient(const struct sockaddr_in6& server); protected: - void set_server(const struct sockaddr_in6& server_addr); - - virtual void receive_from_tun(); - virtual void receive_from_udp(); - - size_t write_to_server(const VpnPacket& packet); - - struct sockaddr_in6 _server_addr; + virtual void acquire_peer( + TunnelledPacket& packet, + const sockaddr_in6& peer_ext_addr); }; diff --git a/UdpVpnServer.cpp b/UdpVpnServer.cpp index bb0468b..bb8a060 100644 --- a/UdpVpnServer.cpp +++ b/UdpVpnServer.cpp @@ -15,6 +15,27 @@ UdpVpnServer::UdpVpnServer(const struct in6_addr& bind_addr6, in_port_t port) bind(bind_addr6, port); } +void UdpVpnServer::acquire_peer( + TunnelledPacket& packet, + const sockaddr_in6& peer_ext_addr) +{ + if(_peer) + return; // Refusing a connection if we already have one + // TODO: reset state at some point/if connection broken + + if(!packet.parse_as_ipv6()) + return; + const in6_addr& peer_inner_addr = packet.get_ipv6_header().source; + _peer = std::make_unique(this, peer_ext_addr, peer_inner_addr); + + packet.get_packet().set_peer(_peer.get()); + + debugf("Got new peer %s:%d -- %s\n", + format_address(peer_ext_addr.sin6_addr.s6_addr), + htons(peer_ext_addr.sin6_port), + format_address(peer_inner_addr.s6_addr)); +} + void UdpVpnServer::bind(const struct in6_addr& bind_addr6, in_port_t port) { int rc; @@ -30,66 +51,3 @@ void UdpVpnServer::bind(const struct in6_addr& bind_addr6, in_port_t port) { debugf("> Listening on port %d\n", port); } - - -std::shared_ptr UdpVpnServer::get_peer_for_ip( - const in6_addr& peer_addr) -{ - auto peer_iter = _peers.find(peer_addr); - if(peer_iter == _peers.end()) // Unknown peer - return nullptr; - return peer_iter->second; -} - -void UdpVpnServer::receive_from_tun() { - VpnPacket packet(_vpn_mtu); - size_t nread = read_from_tun(packet); - if(nread == 0) - return; - - // Recover VpnPeer -- or drop if new - const in6_addr& peer_inner_addr = packet.get_ipv6_header().dest; - std::shared_ptr peer = get_peer_for_ip(peer_inner_addr); - if(!peer) { - debugf("Dropping packet for destination %s -- unknown peer.\n", - format_address(peer_inner_addr.s6_addr)); - return; - } - packet.set_peer(peer); - - kdebugf("Transmitting %s -> %s, size %d\n", - format_address(packet.get_ipv6_header().source.s6_addr), - format_address(packet.get_ipv6_header().dest.s6_addr), - nread); - packet.prepare_for_sending(); - peer->write(packet); -} - -void UdpVpnServer::receive_from_udp() { - VpnPacket packet(_vpn_mtu); - sockaddr_in6 peer_addr; - size_t nread = read_from_udp(packet, peer_addr); - if(nread == 0) - return; - - // Recover VpnPeer -- or create if new - const in6_addr& peer_inner_addr = packet.get_ipv6_header().source; - std::shared_ptr peer = get_peer_for_ip(peer_inner_addr); - if(!peer) { - peer = std::make_shared(this, peer_addr, peer_inner_addr); - _peers.insert({peer_inner_addr, peer}); - - debugf("Got new peer %s:%d -- %s\n", - format_address(peer_addr.sin6_addr.s6_addr), - htons(peer_addr.sin6_port), - format_address(peer_inner_addr.s6_addr)); - } - packet.set_peer(peer); - - // Reinject into tun - kdebugf("Receiving packet #%u of size %d from %s\n", - packet.get_seqno(), - nread, - format_address(packet.get_ipv6_header().source.s6_addr)); - _tun_dev.write(packet.get_payload(), packet.get_payload_size()); -} diff --git a/UdpVpnServer.hpp b/UdpVpnServer.hpp index 2bcb73f..664f0f4 100644 --- a/UdpVpnServer.hpp +++ b/UdpVpnServer.hpp @@ -10,14 +10,11 @@ class UdpVpnServer: public UdpVpn { UdpVpnServer(in_port_t port); UdpVpnServer(const struct in6_addr& bind_addr6, in_port_t port); protected: + virtual void acquire_peer( + TunnelledPacket& packet, + const sockaddr_in6& peer_ext_addr); + void bind(const struct in6_addr& bind_addr6, in_port_t port); - /** Get the peer associated to this (internal) IP. */ - std::shared_ptr get_peer_for_ip(const in6_addr& peer_addr); - - virtual void receive_from_tun(); - virtual void receive_from_udp(); - struct sockaddr_in6 _bind_addr; - std::unordered_map> _peers; }; diff --git a/VpnPacket.cpp b/VpnPacket.cpp index 715ecc7..3dfcf61 100644 --- a/VpnPacket.cpp +++ b/VpnPacket.cpp @@ -6,7 +6,6 @@ const size_t VpnPacket::VPN_HEADER_BYTES = 8; const size_t VpnPacket::TLV_HEADER_BYTES = 3; -uint32_t VpnPacket::_next_general_seqno = 0; static const size_t OUTER_HEADERS_BYTES = 40 /* IPv6 header */ + 8 /* UDP header */; @@ -27,15 +26,8 @@ VpnPacket::~VpnPacket() { delete[] _data; } -VpnPacket::iterator VpnPacket::begin() { - return iterator(VpnPacketTLV(*this, 0)); -} -VpnPacket::iterator VpnPacket::end() { - return iterator(VpnPacketTLV(*this, get_payload_size())); -} - -bool VpnPacket::parse_as_ipv6() { - return parse_ipv6_header(get_payload(), get_payload_size(), _ipv6_header); +VpnPacketTLV VpnPacket::first_tlv() { + return VpnPacketTLV(*this, 0); } uint32_t VpnPacket::get_seqno() const { @@ -60,9 +52,9 @@ void VpnPacket::upon_reception() { } uint32_t VpnPacket::next_seqno() { - if(_peer) - return _peer->next_seqno(); - return _next_general_seqno++; + if(!_peer) + throw PeerNotSet(); + return _peer->next_seqno(); } @@ -70,6 +62,10 @@ VpnPacketTLV::VpnPacketTLV(VpnPacket& packet, size_t payload_offset) : _packet(packet), _tlv_pos(payload_offset) {} +VpnPacketTLV::VpnPacketTLV(const VpnPacketTLV& other) : + _packet(other._packet), _tlv_pos(other._tlv_pos) +{} + VpnPacketTLV VpnPacketTLV::create( VpnPacket& packet, VpnPacket::PayloadType type) { @@ -83,6 +79,15 @@ VpnPacketTLV VpnPacketTLV::create( return tlv; } +VpnPacketTLV VpnPacketTLV::next_tlv() { + size_t next_offset = _tlv_pos + get_payload_size(); + return VpnPacketTLV(_packet, next_offset); +} + +void VpnPacketTLV::seek_next_tlv() { + _tlv_pos = _tlv_pos + VpnPacket::TLV_HEADER_BYTES + get_payload_size(); +} + uint16_t VpnPacketTLV::get_payload_size() const { return *(uint16_t*)(get_data() + 1); } @@ -111,3 +116,7 @@ TunnelledPacket::TunnelledPacket(const VpnPacketTLV& copy) TunnelledPacket TunnelledPacket::create(VpnPacket& packet) { return VpnPacketTLV::create(packet, VpnPacket::PAYLOAD_TYPE_TUNNELLED); } + +bool TunnelledPacket::parse_as_ipv6() { + return parse_ipv6_header(get_payload(), get_payload_size(), _ipv6_header); +} diff --git a/VpnPacket.hpp b/VpnPacket.hpp index 7114c9e..f4df618 100644 --- a/VpnPacket.hpp +++ b/VpnPacket.hpp @@ -29,20 +29,19 @@ * * Where * - Type is one of the values from PayloadType below; + * - Sender ID is an arbitrary value, recommended to be randomly chosen * - Payload size is the size of the payload (excluding headers), in bytes; * - Payload is an arbitrary value, defined by Type */ class VpnPeer; -class VpnPacketTLVIterator; +class VpnPacketTLV; class VpnPacket { public: static const size_t VPN_HEADER_BYTES; static const size_t TLV_HEADER_BYTES; - typedef VpnPacketTLVIterator iterator; - enum PayloadType { PAYLOAD_TYPE_UNDEF, ///< Undefined packet type PAYLOAD_TYPE_TUNNELLED, ///< A tunnelled packet @@ -50,19 +49,15 @@ class VpnPacket { PAYLOAD_TYPE_REMB, ///< Receiver Estimated Maximum Bitrate }; + class PeerNotSet: public std::exception {}; + VpnPacket(size_t mtu); ~VpnPacket(); - iterator begin(); - iterator end(); + VpnPacketTLV first_tlv(); /// Set packet peer -- used for sequence numbers - void set_peer(std::shared_ptr peer) { _peer = peer; } - - /// Try to parse the packet as IPv6, return `false` upon failure. - bool parse_as_ipv6(); - bool ipv6_parsed() const { return _ipv6_parsed; } - const IPv6Header& get_ipv6_header() const { return _ipv6_header; } + void set_peer(VpnPeer* peer) { _peer = peer; } /// Get a pointer to the packet payload (const version) const char* get_payload() const { return _data + VPN_HEADER_BYTES; } @@ -117,23 +112,19 @@ class VpnPacket { inline uint32_t next_seqno(); private: - std::shared_ptr _peer; + VpnPeer* _peer; // raw pointer: we do not own the peer in any way char* _data; size_t _data_space, _data_size; - bool _ipv6_parsed; - IPv6Header _ipv6_header; - uint32_t _reception_timestamp; - - static uint32_t _next_general_seqno; }; /** Base class for a TLV contained in a VpnPacket */ class VpnPacketTLV { public: VpnPacketTLV(VpnPacket& packet, size_t payload_offset); + VpnPacketTLV(const VpnPacketTLV& other); static VpnPacketTLV create( VpnPacket& packet, @@ -142,6 +133,15 @@ class VpnPacketTLV { const VpnPacket& get_packet() const { return _packet; } VpnPacket& get_packet() { return _packet; } + /// Get the next TLV in this packet. + VpnPacketTLV next_tlv(); + /// Point this object to the next TLV (ie. `next_tlv` in place) + void seek_next_tlv(); + + /// Check whether the current TLV is past the packet's end + bool past_the_end() const { + return _tlv_pos >= _packet.get_payload_size(); } + /// Get the offset in the packet size_t get_offset() const { return _tlv_pos; } @@ -155,6 +155,11 @@ class VpnPacketTLV { uint16_t get_payload_size() const; /// Set the current payload size void set_payload_size(uint16_t size); + /// Get the total available raw data space + uint16_t get_payload_space() const { + return _packet.get_payload_space() + - _packet.get_payload_size() + - get_payload_size(); } /// Get a pointer to the raw data (const version) const char* get_data() const { @@ -185,36 +190,6 @@ class VpnPacketTLV { friend class TunnelledPacket; }; -/** An iterator over the VpnPacketTLVs of a VpnPacket */ -class VpnPacketTLVIterator { - public: - using iterator_category = std::input_iterator_tag; - using value_type = VpnPacketTLV; - using difference_type = ptrdiff_t; - using pointer = VpnPacketTLV*; - using reference = VpnPacketTLV&; - - VpnPacketTLVIterator(const VpnPacketTLV& tlv) : _tlv(tlv) {} - - VpnPacketTLVIterator& operator++(); - VpnPacketTLVIterator operator++(int) { - VpnPacketTLVIterator tmp(*this); - operator++(); - return tmp; - } - - bool operator==(const VpnPacketTLVIterator& other) { - return _tlv == other._tlv; } - bool operator!=(const VpnPacketTLVIterator& other) { - return !(operator==(other)); } - bool past_the_end() const { - return _tlv.get_offset() >= _tlv.get_packet().get_payload_size(); } - reference operator*() { return _tlv; } - - private: - VpnPacketTLV _tlv; -}; - /** A packet sent through the VPN tunnel. * * This must instantiated just before filling it with data. */ @@ -223,4 +198,13 @@ class TunnelledPacket: public VpnPacketTLV { TunnelledPacket(VpnPacket& packet, size_t payload_offset); TunnelledPacket(const VpnPacketTLV& copy); static TunnelledPacket create(VpnPacket& packet); + + /// Try to parse the packet as IPv6, return `false` upon failure. + bool parse_as_ipv6(); + bool ipv6_parsed() const { return _ipv6_parsed; } + const IPv6Header& get_ipv6_header() const { return _ipv6_header; } + + private: + bool _ipv6_parsed; + IPv6Header _ipv6_header; }; diff --git a/VpnPeer.cpp b/VpnPeer.cpp index 8dc7470..b4c01e5 100644 --- a/VpnPeer.cpp +++ b/VpnPeer.cpp @@ -6,10 +6,13 @@ #include VpnPeer::VpnPeer(UdpVpn* vpn, const sockaddr_in6& ext_addr, - const in6_addr int_addr) + const in6_addr& int_addr) : _vpn(vpn), _ext_addr(ext_addr), _int_addr(int_addr), _next_seqno(0) {} +void VpnPeer::set_int_addr(const in6_addr& int_addr) { + memcpy(&_int_addr, &int_addr, sizeof(_int_addr)); +} size_t VpnPeer::write(const char* data, size_t len) { ssize_t nsent; diff --git a/VpnPeer.hpp b/VpnPeer.hpp index 1e65db6..66fafd1 100644 --- a/VpnPeer.hpp +++ b/VpnPeer.hpp @@ -20,11 +20,13 @@ class VpnPeer { }; VpnPeer(UdpVpn* vpn, const sockaddr_in6& ext_addr, - const in6_addr int_addr); + const in6_addr& int_addr); const sockaddr_in6& get_ext_addr() const { return _ext_addr; } const in6_addr& get_int_addr() const { return _int_addr; } + void set_int_addr(const in6_addr& int_addr); + size_t write(const char* data, size_t len); size_t write(const VpnPacket& packet);