Compare commits

...

2 commits

13 changed files with 84 additions and 197 deletions

View file

@ -7,19 +7,20 @@
#include <errno.h> #include <errno.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <stdio.h>
#include "TunDevice.hpp" #include "TunDevice.hpp"
static const size_t TUN_MTU = 1500; // TODO determine this cleanly static const size_t TUN_MTU = 1500; // TODO determine this cleanly
TunDevice::TunDevice(const std::string& dev) TunDevice::TunDevice(const char* dev)
{ {
struct ifreq ifr; struct ifreq ifr;
int fd; int fd;
if( (fd = open("/dev/net/tun", O_RDWR)) < 0 ) { if( (fd = open("/dev/net/tun", O_RDWR)) < 0 ) {
throw TunDevice::InitializationError( perror("Tun device: cannot open /dev/net/tun: ");
"Cannot open /dev/net/tun", errno, true); exit(1);
} }
memset(&ifr, 0, sizeof(ifr)); memset(&ifr, 0, sizeof(ifr));
@ -30,16 +31,18 @@ TunDevice::TunDevice(const std::string& dev)
* IFF_NO_PI - Do not provide packet information * IFF_NO_PI - Do not provide packet information
*/ */
ifr.ifr_flags = IFF_TUN | IFF_NO_PI; ifr.ifr_flags = IFF_TUN | IFF_NO_PI;
if(!dev.empty()) { if(dev != nullptr) {
if(dev.size() >= IFNAMSIZ - 2) if(strlen(dev) >= IFNAMSIZ - 2) {
throw TunDevice::InitializationError("Device name is too long."); fprintf(stderr, "Tun device: device name is too long.\n");
strncpy(ifr.ifr_name, dev.c_str(), IFNAMSIZ-1); exit(1);
}
strncpy(ifr.ifr_name, dev, IFNAMSIZ-1);
} }
if(ioctl(fd, TUNSETIFF, (void *) &ifr) < 0){ if(ioctl(fd, TUNSETIFF, (void *) &ifr) < 0){
close(fd); close(fd);
throw TunDevice::InitializationError( perror("Tun device: tunnel interface failed [TUNSETIFF]: ");
"Tunnel interface failed [TUNSETIFF]", errno, true); exit(1);
} }
_dev_name = ifr.ifr_name; _dev_name = ifr.ifr_name;
_fd = fd; _fd = fd;
@ -50,16 +53,16 @@ TunDevice::TunDevice(const std::string& dev)
if(ioctl(sockfd, SIOCGIFFLAGS, (void*) &ifr) < 0) { if(ioctl(sockfd, SIOCGIFFLAGS, (void*) &ifr) < 0) {
close(fd); close(fd);
close(sockfd); close(sockfd);
throw TunDevice::InitializationError( perror("Tun device: could not get tunnel interface flags: ");
"Could not get tunnel interface flags", errno, true); exit(1);
} }
ifr.ifr_flags |= IFF_UP | IFF_RUNNING; ifr.ifr_flags |= IFF_UP | IFF_RUNNING;
if(ioctl(sockfd, SIOCSIFFLAGS, (void*) &ifr) < 0) { if(ioctl(sockfd, SIOCSIFFLAGS, (void*) &ifr) < 0) {
close(fd); close(fd);
close(sockfd); close(sockfd);
throw TunDevice::InitializationError( perror("Tun device: could not bring tunnel interface up: ");
"Could not bring tunnel interface up", errno, true); exit(1);
} }
close(sockfd); close(sockfd);
@ -77,7 +80,7 @@ uint16_t TunDevice::get_mtu() const {
struct ifreq ifr; struct ifreq ifr;
ifr.ifr_addr.sa_family = AF_INET6; ifr.ifr_addr.sa_family = AF_INET6;
strncpy(ifr.ifr_name, _dev_name.c_str(), sizeof(ifr.ifr_name)-1); strncpy(ifr.ifr_name, _dev_name, sizeof(ifr.ifr_name)-1);
if (ioctl(sockfd, SIOCGIFMTU, (caddr_t)&ifr) < 0) if (ioctl(sockfd, SIOCGIFMTU, (caddr_t)&ifr) < 0)
return 0; return 0;
close(sockfd); close(sockfd);
@ -89,7 +92,7 @@ bool TunDevice::set_mtu(uint16_t mtu) {
struct ifreq ifr; struct ifreq ifr;
ifr.ifr_addr.sa_family = AF_INET6; ifr.ifr_addr.sa_family = AF_INET6;
strncpy(ifr.ifr_name, _dev_name.c_str(), sizeof(ifr.ifr_name)-1); strncpy(ifr.ifr_name, _dev_name, sizeof(ifr.ifr_name)-1);
ifr.ifr_mtu = mtu; ifr.ifr_mtu = mtu;
if (ioctl(sockfd, SIOCSIFMTU, (caddr_t)&ifr) < 0) if (ioctl(sockfd, SIOCSIFMTU, (caddr_t)&ifr) < 0)
return false; return false;
@ -102,8 +105,8 @@ size_t TunDevice::poll_packet(char* read_buffer, size_t buf_size, int timeout) {
if(poll_rc < 0) { if(poll_rc < 0) {
if(errno == EINTR) // Interrupt. if(errno == EINTR) // Interrupt.
return 0; return 0;
throw TunDevice::NetError( perror("Tun device: error polling from interface: ");
"Error polling from interface", errno, true); exit(1);
} }
else if(poll_rc == 0 || (_poll_fd.revents & POLLIN) == 0) { else if(poll_rc == 0 || (_poll_fd.revents & POLLIN) == 0) {
// Nothing to read // Nothing to read
@ -116,8 +119,8 @@ size_t TunDevice::poll_packet(char* read_buffer, size_t buf_size, int timeout) {
size_t TunDevice::read(char* read_buffer, size_t buf_size) { size_t TunDevice::read(char* read_buffer, size_t buf_size) {
int nread = ::read(_fd, read_buffer, buf_size); int nread = ::read(_fd, read_buffer, buf_size);
if(nread < 0) { if(nread < 0) {
throw TunDevice::NetError( perror("Tun device: error reading from interface: ");
"Error reading from interface", errno, true); exit(1);
} }
_last_read_size = nread; _last_read_size = nread;
return _last_read_size; return _last_read_size;
@ -126,8 +129,8 @@ size_t TunDevice::read(char* read_buffer, size_t buf_size) {
size_t TunDevice::write(const char* data, size_t len) { size_t TunDevice::write(const char* data, size_t len) {
int nwritten = ::write(_fd, data, len); int nwritten = ::write(_fd, data, len);
if(nwritten < 0) { if(nwritten < 0) {
throw TunDevice::NetError( perror("Tun device: error writing to interface: ");
"Error writing to interface: ", errno, true); exit(1);
} }
return nwritten; return nwritten;
} }

View file

@ -1,32 +1,14 @@
#pragma once #pragma once
#include <string>
#include <poll.h> #include <poll.h>
#include "util.hpp" #include "util.hpp"
class TunDevice { class TunDevice {
public: public:
class InitializationError : public MsgException { TunDevice(const char* dev);
public:
InitializationError(
const std::string& msg,
int code=0,
bool is_perror=false)
: MsgException(msg, code, is_perror) {}
};
class NetError : public MsgException {
public:
NetError(
const std::string& msg,
int code=0,
bool is_perror=false)
: MsgException(msg, code, is_perror) {}
};
TunDevice(const std::string& dev);
~TunDevice(); ~TunDevice();
const std::string& get_dev_name() const { return _dev_name; } const char* get_dev_name() const { return _dev_name; }
int get_fd() const { return _fd; } int get_fd() const { return _fd; }
/** Get the interface's MTU */ /** Get the interface's MTU */
@ -48,7 +30,7 @@ class TunDevice {
private: private:
int _fd; int _fd;
std::string _dev_name; char*_dev_name;
struct pollfd _poll_fd; struct pollfd _poll_fd;
size_t _last_read_size; size_t _last_read_size;
}; };

View file

@ -9,6 +9,7 @@
#include <poll.h> #include <poll.h>
#include <errno.h> #include <errno.h>
#include <math.h> #include <math.h>
#include <stdio.h>
#include "ip_header.hpp" #include "ip_header.hpp"
@ -24,8 +25,10 @@ UdpVpn::UdpVpn()
_tun_dev.set_mtu(VpnPacket::get_tunnelled_mtu(_vpn_mtu)); _tun_dev.set_mtu(VpnPacket::get_tunnelled_mtu(_vpn_mtu));
_socket = socket(AF_INET6, SOCK_DGRAM, 0); _socket = socket(AF_INET6, SOCK_DGRAM, 0);
if(_socket < 0) if(_socket < 0) {
throw UdpVpn::InitializationError("Cannot create socket", errno, true); perror("UdpVpn: cannot create socket: ");
exit(1);
}
} }
UdpVpn::~UdpVpn() { UdpVpn::~UdpVpn() {
@ -58,8 +61,8 @@ void UdpVpn::run() {
if(rc < 0) { if(rc < 0) {
if(errno == EINTR) // Interrupt. if(errno == EINTR) // Interrupt.
continue; continue;
throw UdpVpn::NetError( perror("UdpVpn: error polling from interface: ");
"Error polling from interface", errno, true); exit(1);
} }
// ## Check periodic actions // ## Check periodic actions
@ -124,8 +127,10 @@ size_t UdpVpn::read_from_udp(char* buffer, size_t len,
nread = recvfrom(_socket, buffer, len, 0, nread = recvfrom(_socket, buffer, len, 0,
(struct sockaddr*) &peer_addr, &peer_addr_len); (struct sockaddr*) &peer_addr, &peer_addr_len);
if(nread < 0) if(nread < 0) {
throw UdpVpn::NetError("Cannot receive datagram", errno, true); perror("UdpVpn: cannot receive datagram: ");
exit(1);
}
if(nread == 0) if(nread == 0)
return 0; return 0;

View file

@ -13,23 +13,6 @@
class UdpVpn { class UdpVpn {
public: public:
class InitializationError : public MsgException {
public:
InitializationError(
const std::string& msg,
int code=0,
bool is_perror=false)
: MsgException(msg, code, is_perror) {}
};
class NetError : public MsgException {
public:
NetError(
const std::string& msg,
int code=0,
bool is_perror=false)
: MsgException(msg, code, is_perror) {}
};
UdpVpn(); UdpVpn();
virtual ~UdpVpn(); virtual ~UdpVpn();

View file

@ -1,4 +1,5 @@
#include <string.h> #include <string.h>
#include <stdio.h>
#include "UdpVpnServer.hpp" #include "UdpVpnServer.hpp"
#include "ip_header.hpp" #include "ip_header.hpp"
@ -51,7 +52,8 @@ void UdpVpnServer::bind(const struct in6_addr& bind_addr6, in_port_t port) {
rc = ::bind( rc = ::bind(
_socket, (const struct sockaddr*)&_bind_addr, sizeof(_bind_addr)); _socket, (const struct sockaddr*)&_bind_addr, sizeof(_bind_addr));
if(rc < 0) { if(rc < 0) {
throw UdpVpn::InitializationError("Cannot bind socket", errno, true); perror("UdpVpn: cannot bind socket: ");
exit(1);
} }
debugf("> Listening on port %d\n", port); debugf("> Listening on port %d\n", port);

View file

@ -2,6 +2,8 @@
#include "VpnPeer.hpp" #include "VpnPeer.hpp"
#include <string.h> #include <string.h>
#include <stdio.h>
#include <utility>
const size_t VpnPacket::VPN_HEADER_BYTES = 8; const size_t VpnPacket::VPN_HEADER_BYTES = 8;
const size_t VpnControlPacket::TLV_HEADER_BYTES = 3; const size_t VpnControlPacket::TLV_HEADER_BYTES = 3;
@ -79,8 +81,10 @@ void VpnPacket::upon_reception() {
} }
uint32_t VpnPacket::next_seqno() { uint32_t VpnPacket::next_seqno() {
if(!_peer) if(!_peer) {
throw PeerNotSet(); fprintf(stderr, "ERROR: trying to get seqno without peer.\n");
return 0;
}
return _peer->next_seqno(); return _peer->next_seqno();
} }

View file

@ -3,7 +3,6 @@
/** A packet to be transmitted or received over the VPN socket */ /** A packet to be transmitted or received over the VPN socket */
#include <stdlib.h> #include <stdlib.h>
#include <exception>
#include "ip_header.hpp" #include "ip_header.hpp"
@ -50,7 +49,6 @@ class VpnPacketTLV;
class VpnPacket { class VpnPacket {
public: public:
static const size_t VPN_HEADER_BYTES; static const size_t VPN_HEADER_BYTES;
class PeerNotSet: public std::exception {};
VpnPacket(size_t mtu, bool inbound); VpnPacket(size_t mtu, bool inbound);
~VpnPacket(); ~VpnPacket();

View file

@ -1,6 +1,7 @@
#include "UdpVpn.hpp" #include "UdpVpn.hpp"
#include "congestion_control.hpp" #include "congestion_control.hpp"
#include <stdio.h>
#include <stdint.h> #include <stdint.h>
#include <string.h> #include <string.h>
@ -69,8 +70,10 @@ size_t VpnPeer::write(const char* data, size_t len) {
nsent = sendto(_vpn->get_socket_fd(), data, len, MSG_CONFIRM, nsent = sendto(_vpn->get_socket_fd(), data, len, MSG_CONFIRM,
(const struct sockaddr*) &_ext_addr, sizeof(_ext_addr)); (const struct sockaddr*) &_ext_addr, sizeof(_ext_addr));
if(nsent < 0) if(nsent < 0) {
throw NetError("Could not send UDP packet", errno, true); perror("Could not send UDP packet: ");
exit(1);
}
_tot_bytes_sent += nsent; _tot_bytes_sent += nsent;

View file

@ -68,15 +68,6 @@ class RTTLogger {
class VpnPeer { class VpnPeer {
public: public:
class NetError : public MsgException {
public:
NetError(
const std::string& msg,
int code=0,
bool is_perror=false)
: MsgException(msg, code, is_perror) {}
};
/// Logs the loss reports sent by the remote peer /// Logs the loss reports sent by the remote peer
struct LossReports { struct LossReports {
uint32_t prev_seqno, last_seqno; uint32_t prev_seqno, last_seqno;

View file

@ -47,5 +47,6 @@ void CongestionController::update_lossbased() {
void CongestionController::update_params() { void CongestionController::update_params() {
_bandwidth = _loss_based.bandwidth; // meant to integrate other controllers _bandwidth = _loss_based.bandwidth; // meant to integrate other controllers
_bucket_max_level = _bandwidth * _peer.get_rtt().avg_rtt(); _bucket_max_level = _bandwidth * _peer.get_rtt().avg_rtt();
_bucket_level = std::min(_bucket_level, _bucket_max_level); if(_bucket_level > _bucket_max_level)
_bucket_level = _bucket_max_level;
} }

View file

@ -100,52 +100,41 @@ int main(int argc, char** argv) {
signal(SIGINT, stop_sig_handler); signal(SIGINT, stop_sig_handler);
signal(SIGUSR1, dump_sig_handler); signal(SIGUSR1, dump_sig_handler);
try { if(program_options.listen && program_options.has_peer) {
if(program_options.listen && program_options.has_peer) { fprintf(stderr,
"ERROR: Cannot be a server and a client at the same time "
"-- provide either -l or -s.\n");
return 1;
}
if(program_options.listen) {
vpn_instance = new UdpVpnServer(
program_options.bind_addr, program_options.bind_port);
} else if(program_options.has_peer) {
if(program_options.server_port == 0) {
fprintf(stderr, fprintf(stderr,
"ERROR: Cannot be a server and a client at the same time " "ERROR: A client instance must be given a server port "
"-- provide either -l or -s.\n"); "-- please provide -p.\n");
return 1;
}
if(program_options.listen) {
vpn_instance = new UdpVpnServer(
program_options.bind_addr, program_options.bind_port);
} else if(program_options.has_peer) {
if(program_options.server_port == 0) {
fprintf(stderr,
"ERROR: A client instance must be given a server port "
"-- please provide -p.\n");
return 1;
}
struct sockaddr_in6 server_addr;
server_addr.sin6_family = AF_INET6; memcpy(&server_addr.sin6_addr, &program_options.server_addr,
sizeof(server_addr.sin6_addr));
server_addr.sin6_port = htons(program_options.server_port);
vpn_instance = new UdpVpnClient(server_addr);
} else {
fprintf(stderr,
"ERROR: Must be either a server or a client "
"-- provide either -l or -s.\n");
return 1; return 1;
} }
printf("Starting to listen...\n"); struct sockaddr_in6 server_addr;
vpn_instance->run(); server_addr.sin6_family = AF_INET6; memcpy(&server_addr.sin6_addr, &program_options.server_addr,
sizeof(server_addr.sin6_addr));
server_addr.sin6_port = htons(program_options.server_port);
delete vpn_instance; vpn_instance = new UdpVpnClient(server_addr);
} else {
printf("Shutting down.\n"); fprintf(stderr,
} catch(const TunDevice::InitializationError& exn) { "ERROR: Must be either a server or a client "
fprintf(stderr, "TUN INIT ERROR: %s\n", exn.what()); "-- provide either -l or -s.\n");
} catch(const TunDevice::NetError& exn) { return 1;
fprintf(stderr, "TUN NET ERROR: %s\n", exn.what());
} catch(const UdpVpn::InitializationError& exn) {
fprintf(stderr, "VPN INIT ERROR: %s\n", exn.what());
} catch(const UdpVpn::NetError& exn) {
fprintf(stderr, "VPN NET ERROR: %s\n", exn.what());
} }
printf("Starting to listen...\n");
vpn_instance->run();
delete vpn_instance;
printf("Shutting down.\n");
return 0; return 0;
} }

View file

@ -2,6 +2,7 @@
#include <string.h> #include <string.h>
#include <stdarg.h> #include <stdarg.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <time.h>
#include "util.hpp" #include "util.hpp"
@ -89,40 +90,3 @@ uint32_t current_us_timestamp() {
clock_gettime(CLOCK_MONOTONIC, &now); clock_gettime(CLOCK_MONOTONIC, &now);
return (now.tv_sec * 1000*1000) + (now.tv_nsec / 1000); return (now.tv_sec * 1000*1000) + (now.tv_nsec / 1000);
} }
namespace std {
size_t hash<in6_addr>::operator() (const in6_addr& addr) const {
size_t out_hash = 0;
for(int i=0; i < 4; ++i) {
uint32_t value;
memcpy((unsigned char*)(&value),
addr.s6_addr + 4*i,
4);
out_hash ^= (std::hash<uint32_t>{}(value) << 1);
}
return out_hash;
}
bool equal_to<in6_addr>::operator()(
const in6_addr& lhs, const in6_addr& rhs) const
{
return memcmp(lhs.s6_addr, rhs.s6_addr, sizeof(lhs.s6_addr)) == 0;
}
}
MsgException::MsgException(const std::string& msg, int code, bool is_perror)
: _msg(msg), _code(code)
{
_what = _msg;
if(_code != 0) {
if(is_perror) {
_what += ": ";
_what += strerror(errno);
}
char remainder[20];
sprintf(remainder, " (code %d)", _code);
_what += remainder;
}
}

View file

@ -1,7 +1,5 @@
#pragma once #pragma once
#include <exception>
#include <string>
#include <netinet/in.h> #include <netinet/in.h>
/* Debugging -- taken from babeld */ /* Debugging -- taken from babeld */
@ -79,39 +77,3 @@ uint32_t timespec_us_ellapsed(const struct timespec ref);
/** Get the current timestamp, in microseconds */ /** Get the current timestamp, in microseconds */
uint32_t current_us_timestamp(); uint32_t current_us_timestamp();
/** in6_addr hash & equality */
namespace std {
template<>
class hash<in6_addr> {
public:
size_t operator()(const in6_addr& addr) const;
};
template<>
class equal_to<in6_addr> {
public:
bool operator()(const in6_addr& lhs, const in6_addr& rhs) const;
};
}
/** MsgException -- an exception bearing a passed explanation message
*
* If `is_perror` is true, then the `strerror` corresponding message is appened
* to the message in `what()`.
*/
class MsgException : public std::exception {
public:
MsgException(const std::string& msg, int code=0, bool is_perror=false);
int errcode() const noexcept { return _code; }
const char* what() const noexcept { return _what.c_str(); };
private:
std::string _msg;
int _code;
std::string _what;
};