diff --git a/Makefile b/Makefile index ba03f38..00b0620 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ CXX=g++ CXXFLAGS=-O2 -g -Wall -Wextra -std=c++17 CXXLIBS= -OBJS=TunDevice.o util.o main.o +OBJS=UdpVpn.o TunDevice.o util.o main.o TARGET=congestvpn all: $(TARGET) diff --git a/UdpVpn.cpp b/UdpVpn.cpp new file mode 100644 index 0000000..d6609b0 --- /dev/null +++ b/UdpVpn.cpp @@ -0,0 +1,160 @@ +#include "UdpVpn.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +UdpVpn::UdpVpn() + : _bound(false), _stopped(false), _has_peer(false), _tun_dev("cvpn%d") +{ + memset(&_serv_addr, 0, sizeof(_serv_addr)); + memset(&_peer_addr, 0, sizeof(_peer_addr)); + + _socket = socket(AF_INET6, SOCK_DGRAM, 0); + if(_socket < 0) + throw UdpVpn::InitializationError("Cannot create socket", errno, true); +} + +UdpVpn::~UdpVpn() { + close(_socket); +} + +void UdpVpn::bind(in_port_t port) { + bind(in6addr_any, port); +} +void UdpVpn::bind(const struct in6_addr& bind_addr6, in_port_t port) { + int rc; + + _serv_addr.sin6_family = AF_INET6; + _serv_addr.sin6_port = htons(port); + _serv_addr.sin6_addr = bind_addr6; + + rc = ::bind( + _socket, (const struct sockaddr*)&_serv_addr, sizeof(_serv_addr)); + if(rc < 0) { + throw UdpVpn::InitializationError("Cannot bind socket", errno, true); + } + + debugf("> Listening on port %d\n", port); + _bound = true; +} + +void UdpVpn::set_peer(const sockaddr_in6& peer_addr) { + memcpy(&_peer_addr, &peer_addr, sizeof(_peer_addr)); + _has_peer = true; + + char peer_addr_str[INET6_ADDRSTRLEN]; + inet_ntop( + AF_INET6, &(peer_addr.sin6_addr), peer_addr_str, INET6_ADDRSTRLEN); + debugf("Set peer to %s:%d\n", peer_addr_str, ntohs(peer_addr.sin6_port)); +} + +void UdpVpn::run() { + int rc; + int start_at_fd = 0; // read from polled fds in round-robin fashion + int cur_fd; + int nfds = 2; + struct pollfd poll_fds[2]; + + // poll_fds[0]: tun device + poll_fds[0].fd = _tun_dev.get_fd(); + poll_fds[0].events = POLLIN; + + // poll_fds[1]: UDP socket device + poll_fds[1].fd = _socket; + poll_fds[1].events = POLLIN; + + while(!_stopped) { + rc = poll(poll_fds, nfds, -1); + + if(rc < 0) { + if(errno == EINTR) // Interrupt. + continue; + throw UdpVpn::NetError( + "Error polling from interface", errno, true); + } + else if(rc == 0) // Nothing to read + continue; + + cur_fd = start_at_fd; + do { + if(poll_fds[cur_fd].revents & POLLIN) { + if(cur_fd == 0) + receive_from_tun(); + else if(cur_fd == 1) + receive_from_udp(); + break; + } + + cur_fd = (cur_fd + 1) % nfds; + } while(cur_fd != start_at_fd); + + start_at_fd = (start_at_fd + 1) % nfds; + } +} + +void UdpVpn::receive_from_tun() { + // We know that there is data available -- use `read()` + char buffer[1500]; + size_t nread = _tun_dev.read(buffer, 1500); + + if(nread == 0) + return; + + kdebugf("Transmitting packet of size %d to peer\n", nread); + send_over_udp(buffer, nread); +} + +void UdpVpn::receive_from_udp() { + ssize_t nread; + char buffer[1500]; + struct sockaddr_in6 peer_addr; + socklen_t peer_addr_len = sizeof(peer_addr); + + nread = recvfrom(_socket, buffer, 1500, MSG_WAITALL, + (struct sockaddr*) &peer_addr, &peer_addr_len); + + if(nread < 0) + throw UdpVpn::NetError("Cannot receive datagram", errno, true); + if(nread == 0) + return; + + if(peer_addr.sin6_family != AF_INET6) { + debugf("WARNING: Received non-ipv6 family datagram %d. Ignoring.\n", + peer_addr.sin6_family); + return; + } + if(peer_addr_len != sizeof(peer_addr)) { + debugf("WARNING: received unexpected source address length %u." + "Ignoring.\n", + peer_addr_len); + return; + } + + set_peer(peer_addr); + + // Reinject into tun + kdebugf("Receiving packet of size %d from peer\n", nread); + _tun_dev.write(buffer, nread); +} + +size_t UdpVpn::send_over_udp(const char* data, size_t len) { + ssize_t nsent; + + if(!_has_peer) { + debugf("Dropping packet to be transmitted: no peer.\n"); + return 0; + } + + nsent = sendto(_socket, data, len, MSG_CONFIRM, + (const struct sockaddr*) &_peer_addr, sizeof(_peer_addr)); + if(nsent < 0) + throw NetError("Could not send UDP packet", errno, true); + + return (size_t) nsent; +} diff --git a/UdpVpn.hpp b/UdpVpn.hpp new file mode 100644 index 0000000..3b190be --- /dev/null +++ b/UdpVpn.hpp @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +#include "util.hpp" +#include "TunDevice.hpp" + +/** Handles UDP communication */ + +class UdpVpn { + public: + class InitializationError : public MsgException { + public: + InitializationError( + const std::string& msg, + int code=0, + bool is_perror=false) + : MsgException(msg, code, is_perror) {} + }; + class NetError : public MsgException { + public: + NetError( + const std::string& msg, + int code=0, + bool is_perror=false) + : MsgException(msg, code, is_perror) {} + }; + + UdpVpn(); + ~UdpVpn(); + + void bind(in_port_t port); + void bind(const struct in6_addr& bind_addr6, in_port_t port); + + int get_socket_fd() const { return _socket; } + const TunDevice& get_tun_dev() const { return _tun_dev; } + bool is_bound() const { return _bound; } + + // Sets the peer address + void set_peer(const sockaddr_in6& peer_addr); + void unset_peer() { _has_peer = false; } + + // Run the server. + void run(); + + // Stop the server. Can be called from an interrupt. + void stop() { _stopped = true; } + + protected: + void receive_from_tun(); + void receive_from_udp(); + + size_t send_over_udp(const char* data, size_t len); + + private: + int _socket; + bool _bound; + bool _stopped; + bool _has_peer; + struct sockaddr_in6 _serv_addr, _peer_addr; + + TunDevice _tun_dev; +}; diff --git a/main.cpp b/main.cpp index b1199c1..ac13077 100644 --- a/main.cpp +++ b/main.cpp @@ -1,36 +1,133 @@ #include #include +#include #include +#include -#include "TunDevice.hpp" +#include "UdpVpn.hpp" +#include "util.hpp" -static const size_t BUFFER_SIZE = 1500; -static volatile bool stopped = false; +struct ProgOptions { + bool listen; + in6_addr bind_addr; + in_port_t bind_port; + + bool has_peer; + in6_addr server_addr; + in_port_t server_port; +}; + +static UdpVpn* vpn_instance = nullptr; void stop_sig_handler(int signal) { printf("Received signal %d. Stopping.\n", signal); - stopped = true; + if(vpn_instance != nullptr) + vpn_instance->stop(); } -int main(void) { - char read_buffer[BUFFER_SIZE]; +bool parse_options(int argc, char** argv, ProgOptions& opts) { + int option; + memset(&opts, 0, sizeof(opts)); + + while((option = getopt(argc, argv, ":d:l:b:s:p:")) >= 0) { + switch(option) { + case 'd': // debug + debug = atoi(optarg); + break; + case 'l': // listen, aka "call bind()" + opts.listen = true; + opts.bind_port = atoi(optarg); + break; + case 'b': // bind to address + inet_pton(AF_INET6, optarg, &opts.bind_addr); + break; + case 's': // server -- initial peer + opts.has_peer = true; + inet_pton(AF_INET6, optarg, &opts.server_addr); + break; + case 'p': // port -- initial peer address + opts.server_port = atoi(optarg); + break; + case ':': + fprintf(stderr, "Option %c requires a value.\n", optopt); + return false; + default: + case '?': + fprintf(stderr, "Unknown option %c.\n", optopt); + return false; + } + } + + return true; +} + +int main(int argc, char** argv) { + debug = 0; + ProgOptions program_options; + parse_options(argc, argv, program_options); + + printf("==== Options ====\n"); + printf("Debug level: %d\n", debug); + printf("Bind socket: %d\n", program_options.listen); + if(program_options.listen) { + char bind_addr_str[INET6_ADDRSTRLEN]; + inet_ntop( + AF_INET6, &(program_options.bind_addr), + bind_addr_str, INET6_ADDRSTRLEN); + printf("\taddr: %s\n", bind_addr_str); + printf("\tport: %d\n", program_options.bind_port); + } + printf("Start with peer: %d\n", program_options.has_peer); + if(program_options.has_peer) { + char peer_addr_str[INET6_ADDRSTRLEN]; + inet_ntop( + AF_INET6, &(program_options.server_addr), + peer_addr_str, INET6_ADDRSTRLEN); + printf("\taddr: %s\n", peer_addr_str); + printf("\tport: %d\n", program_options.server_port); + } + printf("=== END OPTIONS ==\n\n"); signal(SIGINT, stop_sig_handler); - try { - TunDevice tun_dev("cvpn%d"); - printf("Tunnel device opened: <%s>, fd <%d>.\n", - tun_dev.get_dev_name().c_str(), tun_dev.get_fd()); - while(!stopped) { - size_t packet_size = tun_dev.read_packet(read_buffer, BUFFER_SIZE); - if(packet_size > 0) { - printf("Received packet of size %lu.\n", packet_size); - tun_dev.write_packet(read_buffer, packet_size); + try { + UdpVpn vpn; + vpn_instance = &vpn; + + if(program_options.listen) { + vpn.bind(program_options.bind_addr, program_options.bind_port); + } + + if(program_options.has_peer) { + if(program_options.server_port == 0) { + fprintf(stderr, + "Initial peer set without a port -- ignoring.\n"); + } + else { + struct sockaddr_in6 server_addr; + server_addr.sin6_family = AF_INET6; + memcpy(&server_addr.sin6_addr, + &program_options.server_addr, + sizeof(server_addr.sin6_addr)); + server_addr.sin6_port = htons(program_options.server_port); + vpn.set_peer(server_addr); } } + + printf("Starting to listen...\n"); + vpn.run(); + vpn_instance = nullptr; + printf("Shutting down.\n"); } catch(const TunDevice::InitializationError& exn) { - fprintf(stderr, "ERROR: %s\n", exn.what()); + fprintf(stderr, "TUN INIT ERROR: %s\n", exn.what()); + } catch(const TunDevice::NetError& exn) { + fprintf(stderr, "TUN NET ERROR: %s\n", exn.what()); + } catch(const UdpVpn::InitializationError& exn) { + fprintf(stderr, "VPN INIT ERROR: %s\n", exn.what()); + } catch(const UdpVpn::NetError& exn) { + fprintf(stderr, "VPN NET ERROR: %s\n", exn.what()); } + return 0; }