TunDevice: add read/write, cleaner error handling

This commit is contained in:
Théophile Bastian 2020-06-03 14:54:01 +02:00
parent 1f96a34a37
commit 53444e725c
4 changed files with 83 additions and 11 deletions

View file

@ -4,18 +4,23 @@
#include <sys/ioctl.h> #include <sys/ioctl.h>
#include <fcntl.h> #include <fcntl.h>
#include <unistd.h> #include <unistd.h>
#include <errno.h>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include "TunDevice.hpp" #include "TunDevice.hpp"
static const size_t TUN_MTU = 1500; // TODO determine this cleanly
TunDevice::TunDevice(const std::string& dev) TunDevice::TunDevice(const std::string& dev)
{ {
struct ifreq ifr; struct ifreq ifr;
int fd, err; int fd;
if( (fd = open("/dev/net/tun", O_RDWR)) < 0 ) if( (fd = open("/dev/net/tun", O_RDWR)) < 0 ) {
throw TunDevice::InitializationError("Cannot open /dev/net/tun", fd); throw TunDevice::InitializationError(
"Cannot open /dev/net/tun", errno, true);
}
memset(&ifr, 0, sizeof(ifr)); memset(&ifr, 0, sizeof(ifr));
@ -31,16 +36,50 @@ TunDevice::TunDevice(const std::string& dev)
strncpy(ifr.ifr_name, dev.c_str(), IFNAMSIZ-1); strncpy(ifr.ifr_name, dev.c_str(), IFNAMSIZ-1);
} }
if( (err = ioctl(fd, TUNSETIFF, (void *) &ifr)) < 0 ){ if(ioctl(fd, TUNSETIFF, (void *) &ifr) < 0){
close(fd); close(fd);
throw TunDevice::InitializationError( throw TunDevice::InitializationError(
"Tunnel interface failed [TUNSETIFF]", err "Tunnel interface failed [TUNSETIFF]", errno, true);
);
} }
_dev_name = ifr.ifr_name; _dev_name = ifr.ifr_name;
_fd = fd; _fd = fd;
// The device is now fully set up
_poll_fd.fd = _fd;
_poll_fd.events = POLLIN;
} }
TunDevice::~TunDevice() { TunDevice::~TunDevice() {
close(_fd); close(_fd);
} }
size_t TunDevice::read_packet(char* read_buffer, size_t buf_size, int timeout) {
int poll_rc = poll(&_poll_fd, 1, timeout);
if(poll_rc < 0) {
if(errno == EINTR) // Interrupt.
return 0;
throw TunDevice::NetError(
"Error polling from interface", errno, true);
}
else if(poll_rc == 0 || (_poll_fd.revents & POLLIN) == 0) {
// Nothing to read
return 0;
}
int nread = read(_fd, read_buffer, buf_size);
if(nread < 0) {
throw TunDevice::NetError(
"Error reading from interface", errno, true);
}
_last_read_size = nread;
return _last_read_size;
}
size_t TunDevice::write_packet(const char* data, size_t len) {
int nwritten = write(_fd, data, len);
if(nwritten < 0) {
throw TunDevice::NetError(
"Error writing to interface: ", errno, true);
}
return nwritten;
}

View file

@ -1,14 +1,26 @@
#pragma once #pragma once
#include <string> #include <string>
#include <poll.h>
#include "util.hpp" #include "util.hpp"
class TunDevice { class TunDevice {
public: public:
class InitializationError : public MsgException { class InitializationError : public MsgException {
public: public:
InitializationError(const std::string& msg, int code=0) InitializationError(
: MsgException(msg, code) {} const std::string& msg,
int code=0,
bool is_perror=false)
: MsgException(msg, code, is_perror) {}
};
class NetError : public MsgException {
public:
NetError(
const std::string& msg,
int code=0,
bool is_perror=false)
: MsgException(msg, code, is_perror) {}
}; };
TunDevice(const std::string& dev); TunDevice(const std::string& dev);
@ -16,7 +28,18 @@ class TunDevice {
const std::string& get_dev_name() const { return _dev_name; } const std::string& get_dev_name() const { return _dev_name; }
int get_fd() const { return _fd; } int get_fd() const { return _fd; }
/* Reads a packet from the device.
* Timeouts after `timeout` ms, or never if `timeout < 0`.
* Upon timeout, returns 0.
*/
size_t read_packet(char* read_buffer, size_t buf_size, int timeout=-1);
size_t write_packet(const char* data, size_t len);
private: private:
int _fd; int _fd;
std::string _dev_name; std::string _dev_name;
struct pollfd _poll_fd;
size_t _last_read_size;
}; };

View file

@ -1,13 +1,19 @@
#include <cstdio> #include <cstdio>
#include <string.h>
#include "util.hpp" #include "util.hpp"
MsgException::MsgException(const std::string& msg, int code) MsgException::MsgException(const std::string& msg, int code, bool is_perror)
: _msg(msg), _code(code) : _msg(msg), _code(code)
{ {
_what = _msg; _what = _msg;
if(_code != 0) { if(_code != 0) {
if(is_perror) {
_what += ": ";
_what += strerror(errno);
}
char remainder[20]; char remainder[20];
sprintf(remainder, " (code %d)", _code); sprintf(remainder, " (code %d)", _code);
_what += remainder; _what += remainder;

View file

@ -3,10 +3,14 @@
#include <exception> #include <exception>
#include <string> #include <string>
// MsgException -- an exception bearing a passed explanation message /** MsgException -- an exception bearing a passed explanation message
*
* If `is_perror` is true, then the `strerror` corresponding message is appened
* to the message in `what()`.
*/
class MsgException : public std::exception { class MsgException : public std::exception {
public: public:
MsgException(const std::string& msg, int code=0); MsgException(const std::string& msg, int code=0, bool is_perror=false);
int errcode() const noexcept { return _code; } int errcode() const noexcept { return _code; }
const char* what() const noexcept { return _what.c_str(); }; const char* what() const noexcept { return _what.c_str(); };