Compare commits

...

2 commits

Author SHA1 Message Date
Théophile Bastian 7b5ffa4d46 Implement loss-based congestion controller
Still lacks an outbound bandwidth limiter and outbound actual emission
bitrate, to avoid increasing the available bandwidth when the bandwidth
is not saturated.
2020-07-03 16:13:47 +02:00
Théophile Bastian 15f9625a6d Only keep track of total packet loss so far
Remove the packet loss history for last 128 packets
2020-07-02 20:41:56 +02:00
11 changed files with 223 additions and 74 deletions

View file

@ -7,6 +7,7 @@ OBJS= \
VpnPeer.o \
VpnPacket.o \
TunDevice.o \
congestion_control.o \
ip_header.o util.o main.o
TARGET=congestvpn

View file

@ -195,6 +195,10 @@ void UdpVpn::receive_from_udp() {
tlv.seek_next_tlv())
{
switch(tlv.get_type()) {
case VpnPacketTLV::PAYLOAD_TYPE_LOSS_REPORT:
if(_peer)
_peer->log_loss_report(VpnTlvLossReport(tlv));
break;
case VpnPacketTLV::PAYLOAD_TYPE_RTTQ:
if(_peer)
_peer->make_rtta_for(VpnTlvRTTQ(tlv));
@ -233,9 +237,19 @@ void UdpVpn::receive_tunnelled_tlv(VpnDataPacket& packet) {
}
void UdpVpn::dump_state() const {
if(!_peer) {
printf("===== Cannot dump state, no peer yet =====\n");
return;
}
const CongestionController& congest_ctrl =
_peer->get_congestion_controller();
printf("====== State dump ======\n");
printf("Packet loss rate: %.0lf%%\n",
printf("Packet loss rate (inbound): %.0lf%%\n",
round(_peer->get_loss_logger().get_loss_rate() * 100));
printf("Packet loss rate (outbound): %.0lf%%\n",
round(_peer->get_loss_reports().loss_rate() * 100));
printf("RTT: %.02lf ms avg, %.02lf ms last [last updated: %lu ms ago]\n",
(double)_peer->get_rtt().avg_rtt() / 1e3,
(double)_peer->get_rtt().cur_rtt() / 1e3,
@ -243,5 +257,8 @@ void UdpVpn::dump_state() const {
std::chrono::steady_clock::now()
- _peer->get_rtt().get_last_update()).count()
);
printf("Available bandwidth:\n\t%s [loss based controller]\n",
human_readable_unit(congest_ctrl.get_lossbased_bandwidth(), "B")
);
printf("==== End state dump ====\n");
}

View file

@ -178,6 +178,36 @@ void VpnPacketTLV::set_type(VpnPacketTLV::PayloadType type) {
*(uint8_t*)(get_data()) = (uint8_t) type;
}
/* ========== VpnTlvLossReport ========== */
const uint32_t VpnTlvLossReport::REP_SEQNO_POS = 0,
VpnTlvLossReport::REP_LOSS_POS = 4;
VpnTlvLossReport::VpnTlvLossReport(
VpnControlPacket& packet, size_t payload_offset)
: VpnPacketTLV(packet, payload_offset) {}
VpnTlvLossReport::VpnTlvLossReport(const VpnPacketTLV& other)
: VpnPacketTLV(other) {}
VpnTlvLossReport VpnTlvLossReport::create(VpnControlPacket& packet) {
VpnTlvLossReport tlv =
VpnPacketTLV::create(packet, VpnPacketTLV::PAYLOAD_TYPE_LOSS_REPORT);
tlv.set_payload_size(8);
memset(tlv.get_payload(), 0, 8);
return tlv;
}
uint32_t VpnTlvLossReport::get_report_seqno() const {
return ntohl(*(uint32_t*)(get_payload() + REP_SEQNO_POS));
}
void VpnTlvLossReport::set_report_seqno(uint32_t seqno) {
*(uint32_t*)(get_payload() + REP_SEQNO_POS) = htonl(seqno);
}
uint32_t VpnTlvLossReport::get_losses() const {
return ntohl(*(uint32_t*)(get_payload() + REP_LOSS_POS));
}
void VpnTlvLossReport::set_losses(uint32_t losses) {
*(uint32_t*)(get_payload() + REP_LOSS_POS) = htonl(losses);
}
/* ========== VpnTlvRTTQ ========== */
VpnTlvRTTQ::VpnTlvRTTQ(VpnControlPacket& packet, size_t payload_offset)

View file

@ -178,7 +178,7 @@ class VpnPacketTLV {
enum PayloadType {
PAYLOAD_TYPE_UNDEF, ///< Undefined packet type
PAYLOAD_TYPE_RR, ///< Receiver report
PAYLOAD_TYPE_LOSS_REPORT, ///< Loss report
PAYLOAD_TYPE_REMB, ///< Receiver Estimated Maximum Bitrate
PAYLOAD_TYPE_RTTQ, ///< RTT update query
PAYLOAD_TYPE_RTTA, ///< RTT update answer
@ -246,6 +246,23 @@ class VpnPacketTLV {
size_t _tlv_pos;
};
class VpnTlvLossReport: public VpnPacketTLV {
public:
VpnTlvLossReport(VpnControlPacket& packet, size_t payload_offset);
VpnTlvLossReport(const VpnPacketTLV& other);
static VpnTlvLossReport create(VpnControlPacket& packet);
uint32_t get_report_seqno() const;
void set_report_seqno(uint32_t seqno);
uint32_t get_losses() const;
void set_losses(uint32_t losses);
private:
static const uint32_t REP_SEQNO_POS, REP_LOSS_POS;
};
class VpnTlvRTTQ: public VpnPacketTLV {
public:
VpnTlvRTTQ(VpnControlPacket& packet, size_t payload_offset);

View file

@ -1,5 +1,5 @@
#include "VpnPeer.hpp"
#include "UdpVpn.hpp"
#include "congestion_control.hpp"
#include <cstdint>
#include <cstring>
@ -10,11 +10,18 @@ const unsigned int RTTLogger::BASE_UPDATE_DELAY = 1000; // ms
VpnPeer::VpnPeer(UdpVpn* vpn, const sockaddr_in6& ext_addr,
const in6_addr& int_addr)
: _vpn(vpn), _ext_addr(ext_addr), _int_addr(int_addr), _next_send_seqno(0)
: _vpn(vpn), _ext_addr(ext_addr), _int_addr(int_addr), _next_send_seqno(0),
_congestion_controller(*this)
{
cycle_next_control();
}
void VpnPeer::make_loss_report() {
VpnTlvLossReport report = VpnTlvLossReport::create(*_next_control_packet);
report.set_report_seqno(_packet_loss.get_cur_seqno());
report.set_losses(_packet_loss.get_tot_losses());
}
void VpnPeer::cycle_next_control() {
_next_control_packet =
std::make_unique<VpnControlPacket>(_vpn->get_mtu(), false);
@ -25,6 +32,16 @@ void VpnPeer::set_int_addr(const in6_addr& int_addr) {
memcpy(&_int_addr, &int_addr, sizeof(_int_addr));
}
void VpnPeer::log_loss_report(const VpnTlvLossReport& loss_rep) {
_loss_reports.prev_seqno = _loss_reports.last_seqno;
_loss_reports.prev_losses = _loss_reports.last_losses;
_loss_reports.last_seqno = loss_rep.get_report_seqno();
_loss_reports.last_losses = loss_rep.get_losses();
_congestion_controller.update_lossbased();
}
size_t VpnPeer::write(const char* data, size_t len) {
ssize_t nsent;
@ -45,8 +62,10 @@ void VpnPeer::got_inbound_packet(const VpnPacket& packet) {
}
bool VpnPeer::send_control_packet() {
if(_rtt.update_due(!_next_control_packet->is_empty()))
if(_rtt.update_due(!_next_control_packet->is_empty())) {
VpnTlvRTTQ::create(*_next_control_packet);
make_loss_report();
}
if(!_next_control_packet->is_empty()) {
_next_control_packet->prepare_for_sending();
@ -63,44 +82,57 @@ void VpnPeer::make_rtta_for(const VpnTlvRTTQ& rttq) {
rtta.set_recv_ts(rttq.get_packet().get_reception_timestamp());
}
PacketLossLogger::PacketLossLogger() : _cur_seqno(0) {}
PacketLossLogger::PacketLossLogger() :
_cur_seqno(0), _tot_losses(0), _last_window_losses(0), _win_start_losses(0)
{}
void PacketLossLogger::log_packet(uint32_t seqno) {
uint32_t m_seqno = seqno % PACKET_LOSS_HISTSIZE;
int64_t diff = (int64_t)seqno - _cur_seqno;
if(diff == 1) {
_cur_seqno++;
_packet_loss_hist.reset(m_seqno);
maybe_start_window();
while(_received_ahead.test((_cur_seqno + 1) % PACKET_LOST_AFTER)) {
_cur_seqno++;
_packet_loss_hist.reset(_cur_seqno % PACKET_LOSS_HISTSIZE);
_received_ahead.reset(_cur_seqno % PACKET_LOST_AFTER);
maybe_start_window();
}
} else if(LIKELY(diff > 1)) {
if(diff < PACKET_LOST_AFTER)
_received_ahead.set(seqno % PACKET_LOST_AFTER);
else if(diff < PACKET_LOSS_HISTSIZE) {
else if(diff < PACKET_LOSS_WINDOW) {
// Packet too much forwards -- consider _cur_seqno lost
for(int offs=1; offs < PACKET_LOST_AFTER; ++offs) {
_packet_loss_hist[(_cur_seqno + offs) % PACKET_LOSS_HISTSIZE] =
!_received_ahead[(_cur_seqno + offs) % PACKET_LOST_AFTER];
if(_cur_seqno % PACKET_LOSS_WINDOW > seqno % PACKET_LOSS_WINDOW) {
// This loss crosses a window border
for(int offs=0; offs < PACKET_LOST_AFTER; ++offs) {
maybe_start_window(offs);
if(!_received_ahead[(_cur_seqno + offs) % PACKET_LOST_AFTER])
_tot_losses++;
}
} else {
_tot_losses += PACKET_LOST_AFTER - _received_ahead.count();
}
_received_ahead.reset();
_cur_seqno = seqno;
_packet_loss_hist.reset(m_seqno);
} else
reboot(); // This is a huge gap -- reboot
} else {
if(diff < - 2*PACKET_LOSS_HISTSIZE)
if(diff < - 2*PACKET_LOSS_WINDOW)
reboot(); // this is too much backwards -- something's wrong, reboot
// else: ignore, we've moved forward and counted the packet as lost
}
}
void PacketLossLogger::reboot() {
_packet_loss_hist.reset();
_received_ahead.reset();
// _tot_losses unchanged
}
void PacketLossLogger::maybe_start_window(int offs) {
if(_cur_seqno + offs % PACKET_LOSS_WINDOW == 0) {
_last_window_losses = _win_start_losses;
_win_start_losses = _tot_losses;
}
}
RTTLogger::RTTLogger() :

View file

@ -7,21 +7,19 @@
#include <chrono>
#include "util.hpp"
#include "VpnPacket.hpp"
#include "congestion_control.hpp"
class UdpVpn;
const int PACKET_LOSS_HISTSIZE = 128, PACKET_LOST_AFTER = 8;
const int PACKET_LOSS_WINDOW = 128, PACKET_LOST_AFTER = 8;
class PacketLossLogger {
public:
PacketLossLogger();
void log_packet(uint32_t seqno);
double get_loss_rate() const {
return (double)_packet_loss_hist.count() / PACKET_LOSS_HISTSIZE;
}
const std::bitset<PACKET_LOSS_HISTSIZE> get_loss_hist() const {
return _packet_loss_hist;
return (double)(_win_start_losses - _last_window_losses)
/ (double)PACKET_LOSS_WINDOW;
}
const std::bitset<PACKET_LOST_AFTER> get_received_ahead() const {
@ -29,13 +27,18 @@ class PacketLossLogger {
}
uint32_t get_cur_seqno() const { return _cur_seqno; }
unsigned int get_tot_losses() const { return _tot_losses; }
private:
void reboot(); ///< completely reset the internal state
std::bitset<PACKET_LOSS_HISTSIZE> _packet_loss_hist;
/// roll loss window values if `_cur_seqno + offs` is a window start.
void maybe_start_window(int offs=0);
std::bitset<PACKET_LOST_AFTER> _received_ahead;
uint32_t _cur_seqno;
unsigned int _tot_losses;
unsigned int _last_window_losses, _win_start_losses;
};
/** Round-trip time logger. All timestamps/delays are in microseconds. */
@ -75,6 +78,17 @@ class VpnPeer {
: MsgException(msg, code, is_perror) {}
};
/// Logs the loss reports sent by the remote peer
struct LossReports {
uint32_t prev_seqno, last_seqno;
uint32_t prev_losses, last_losses;
double loss_rate() const {
return (double)(last_losses - prev_losses)
/ (double)(last_seqno - prev_seqno);
}
};
VpnPeer(UdpVpn* vpn, const sockaddr_in6& ext_addr,
const in6_addr& int_addr);
@ -83,10 +97,14 @@ class VpnPeer {
void set_int_addr(const in6_addr& int_addr);
const PacketLossLogger& get_loss_logger() { return _packet_loss; }
const RTTLogger& get_rtt() { return _rtt; }
const PacketLossLogger& get_loss_logger() const { return _packet_loss; }
const RTTLogger& get_rtt() const { return _rtt; }
const LossReports& get_loss_reports() const { return _loss_reports; }
const CongestionController& get_congestion_controller() const {
return _congestion_controller; }
void log_rtta(const VpnTlvRTTA& rtta) { _rtt.log(rtta); }
void log_loss_report(const VpnTlvLossReport& loss_rep);
size_t write(const char* data, size_t len);
size_t write(const VpnPacket& packet);
@ -104,7 +122,8 @@ class VpnPeer {
void make_rtta_for(const VpnTlvRTTQ& rttq);
private: // meth
void cycle_next_control(); /// Generate a fresh next control packet
void make_loss_report(); ///< Add a loss report to the next control packet
void cycle_next_control(); ///< Generate a fresh next control packet
private:
UdpVpn* _vpn;
@ -113,7 +132,9 @@ class VpnPeer {
uint32_t _next_send_seqno;
PacketLossLogger _packet_loss;
LossReports _loss_reports;
RTTLogger _rtt;
CongestionController _congestion_controller;
std::unique_ptr<VpnControlPacket> _next_control_packet;
};

27
congestion_control.cpp Normal file
View file

@ -0,0 +1,27 @@
#include "congestion_control.hpp"
#include "VpnPeer.hpp"
CongestionController::CongestionController(const VpnPeer& peer):
_peer(peer)
{
_last_seqno = _peer.get_loss_logger().get_cur_seqno();
_loss_based.bandwidth = 3e5; // 300kBps seems a good value to start with
}
void CongestionController::update_lossbased() {
const VpnPeer::LossReports& loss_rep = _peer.get_loss_reports();
uint32_t delta_seqno = loss_rep.last_seqno - loss_rep.prev_seqno;
unsigned int delta_losses = loss_rep.last_losses - loss_rep.prev_losses;
double loss_rate = (double)delta_losses / (double)delta_seqno;
if(loss_rate < 0.02) // FIXME only if the bandwidth is used
_loss_based.bandwidth *= 1.05;
else if(loss_rate >= 0.1)
_loss_based.bandwidth *= (1 - 0.5 * loss_rate);
}
uint64_t CongestionController::get_bandwidth() const {
return _loss_based.bandwidth;
}

28
congestion_control.hpp Normal file
View file

@ -0,0 +1,28 @@
#pragma once
#include <stdint.h>
class VpnPeer;
class CongestionController {
public:
struct LossBased {
uint64_t bandwidth; // bytes per second
};
CongestionController(const VpnPeer& peer);
const VpnPeer& get_peer() const { return _peer; }
void update_lossbased();
uint64_t get_bandwidth() const;
uint64_t get_lossbased_bandwidth() const {
return _loss_based.bandwidth; }
private:
const VpnPeer& _peer;
LossBased _loss_based;
uint32_t _last_seqno; // seqno at the last update time
};

View file

@ -1,48 +0,0 @@
#include <cstdio>
#include <vector>
#include "VpnPeer.hpp"
template<int len>
std::string bitset_to_string(
const std::bitset<len>& bs,
char c_f, char c_t, size_t mark_pos, char cm_f, char cm_t)
{
std::string out;
for(size_t pos=0; pos < len; ++pos) {
if(pos == mark_pos)
out += bs[pos] ? cm_t : cm_f;
else
out += bs[pos] ? c_t : c_f;
}
return out;
}
void dump(PacketLossLogger& pllog, int received) {
printf("%03d, %.03lf ## %s ## %s\n",
received,
pllog.get_loss_rate(),
bitset_to_string<PACKET_LOSS_HISTSIZE>(
pllog.get_loss_hist(), '_', 'X',
pllog.get_cur_seqno() % PACKET_LOSS_HISTSIZE,
'|', '#').c_str(),
bitset_to_string<PACKET_LOST_AFTER>(
pllog.get_received_ahead(), '_', 'X',
pllog.get_cur_seqno() % PACKET_LOST_AFTER,
'|', '#').c_str());
}
int main(void) {
PacketLossLogger pllog;
std::vector<int> sequence({
1, 2, 3, 5, 6, 4, 8, 7, 8, 9,
12, 13, 14, 15, 16, 17, 18, 19, 20, 10, 21
});
dump(pllog, 0);
for(auto val=sequence.begin(); val != sequence.end(); ++val) {
pllog.log_packet(*val);
dump(pllog, *val);
}
return 0;
}

View file

@ -33,6 +33,27 @@ format_address(const unsigned char *address)
return buf[i];
}
const char*
human_readable_unit(size_t value, const char* unit) {
static char buf[4][24];
static int cur_buf = 0;
static const char MULTIPLIERS[] = {' ', 'k', 'M', 'G', 'T'};
static const int NB_MULTIPLIERS = 5;
int multiplier = 0;
double div_value = value;
cur_buf = (cur_buf + 1) % 4;
while(div_value >= 1000 && multiplier < NB_MULTIPLIERS) {
div_value /= 1000;
multiplier++;
}
snprintf(buf[cur_buf], 24, "%.1lf %c%s",
div_value, MULTIPLIERS[multiplier], unit);
return buf[cur_buf];
}
namespace std {
size_t hash<in6_addr>::operator() (const in6_addr& addr) const {
size_t out_hash = 0;

View file

@ -47,6 +47,9 @@ void do_debugf(int level, const char *format, ...);
/** format_address -- taken from babeld */
const char* format_address(const unsigned char* address);
/** turns a value into a human-readable one, eg. "21.2 kB" */
const char* human_readable_unit(size_t value, const char* unit);
/** remove the upper bit from a microsecond timestamp, to conform with the
* packet header timestamp format. */