Implement loss-based congestion controller

Still lacks an outbound bandwidth limiter and outbound actual emission
bitrate, to avoid increasing the available bandwidth when the bandwidth
is not saturated.
This commit is contained in:
Théophile Bastian 2020-07-03 16:13:47 +02:00
parent 15f9625a6d
commit 7b5ffa4d46
10 changed files with 188 additions and 6 deletions

View File

@ -7,6 +7,7 @@ OBJS= \
VpnPeer.o \
VpnPacket.o \
TunDevice.o \
congestion_control.o \
ip_header.o util.o main.o
TARGET=congestvpn

View File

@ -195,6 +195,10 @@ void UdpVpn::receive_from_udp() {
tlv.seek_next_tlv())
{
switch(tlv.get_type()) {
case VpnPacketTLV::PAYLOAD_TYPE_LOSS_REPORT:
if(_peer)
_peer->log_loss_report(VpnTlvLossReport(tlv));
break;
case VpnPacketTLV::PAYLOAD_TYPE_RTTQ:
if(_peer)
_peer->make_rtta_for(VpnTlvRTTQ(tlv));
@ -233,9 +237,19 @@ void UdpVpn::receive_tunnelled_tlv(VpnDataPacket& packet) {
}
void UdpVpn::dump_state() const {
if(!_peer) {
printf("===== Cannot dump state, no peer yet =====\n");
return;
}
const CongestionController& congest_ctrl =
_peer->get_congestion_controller();
printf("====== State dump ======\n");
printf("Packet loss rate: %.0lf%%\n",
printf("Packet loss rate (inbound): %.0lf%%\n",
round(_peer->get_loss_logger().get_loss_rate() * 100));
printf("Packet loss rate (outbound): %.0lf%%\n",
round(_peer->get_loss_reports().loss_rate() * 100));
printf("RTT: %.02lf ms avg, %.02lf ms last [last updated: %lu ms ago]\n",
(double)_peer->get_rtt().avg_rtt() / 1e3,
(double)_peer->get_rtt().cur_rtt() / 1e3,
@ -243,5 +257,8 @@ void UdpVpn::dump_state() const {
std::chrono::steady_clock::now()
- _peer->get_rtt().get_last_update()).count()
);
printf("Available bandwidth:\n\t%s [loss based controller]\n",
human_readable_unit(congest_ctrl.get_lossbased_bandwidth(), "B")
);
printf("==== End state dump ====\n");
}

View File

@ -178,6 +178,36 @@ void VpnPacketTLV::set_type(VpnPacketTLV::PayloadType type) {
*(uint8_t*)(get_data()) = (uint8_t) type;
}
/* ========== VpnTlvLossReport ========== */
const uint32_t VpnTlvLossReport::REP_SEQNO_POS = 0,
VpnTlvLossReport::REP_LOSS_POS = 4;
VpnTlvLossReport::VpnTlvLossReport(
VpnControlPacket& packet, size_t payload_offset)
: VpnPacketTLV(packet, payload_offset) {}
VpnTlvLossReport::VpnTlvLossReport(const VpnPacketTLV& other)
: VpnPacketTLV(other) {}
VpnTlvLossReport VpnTlvLossReport::create(VpnControlPacket& packet) {
VpnTlvLossReport tlv =
VpnPacketTLV::create(packet, VpnPacketTLV::PAYLOAD_TYPE_LOSS_REPORT);
tlv.set_payload_size(8);
memset(tlv.get_payload(), 0, 8);
return tlv;
}
uint32_t VpnTlvLossReport::get_report_seqno() const {
return ntohl(*(uint32_t*)(get_payload() + REP_SEQNO_POS));
}
void VpnTlvLossReport::set_report_seqno(uint32_t seqno) {
*(uint32_t*)(get_payload() + REP_SEQNO_POS) = htonl(seqno);
}
uint32_t VpnTlvLossReport::get_losses() const {
return ntohl(*(uint32_t*)(get_payload() + REP_LOSS_POS));
}
void VpnTlvLossReport::set_losses(uint32_t losses) {
*(uint32_t*)(get_payload() + REP_LOSS_POS) = htonl(losses);
}
/* ========== VpnTlvRTTQ ========== */
VpnTlvRTTQ::VpnTlvRTTQ(VpnControlPacket& packet, size_t payload_offset)

View File

@ -178,7 +178,7 @@ class VpnPacketTLV {
enum PayloadType {
PAYLOAD_TYPE_UNDEF, ///< Undefined packet type
PAYLOAD_TYPE_RR, ///< Receiver report
PAYLOAD_TYPE_LOSS_REPORT, ///< Loss report
PAYLOAD_TYPE_REMB, ///< Receiver Estimated Maximum Bitrate
PAYLOAD_TYPE_RTTQ, ///< RTT update query
PAYLOAD_TYPE_RTTA, ///< RTT update answer
@ -246,6 +246,23 @@ class VpnPacketTLV {
size_t _tlv_pos;
};
class VpnTlvLossReport: public VpnPacketTLV {
public:
VpnTlvLossReport(VpnControlPacket& packet, size_t payload_offset);
VpnTlvLossReport(const VpnPacketTLV& other);
static VpnTlvLossReport create(VpnControlPacket& packet);
uint32_t get_report_seqno() const;
void set_report_seqno(uint32_t seqno);
uint32_t get_losses() const;
void set_losses(uint32_t losses);
private:
static const uint32_t REP_SEQNO_POS, REP_LOSS_POS;
};
class VpnTlvRTTQ: public VpnPacketTLV {
public:
VpnTlvRTTQ(VpnControlPacket& packet, size_t payload_offset);

View File

@ -1,5 +1,5 @@
#include "VpnPeer.hpp"
#include "UdpVpn.hpp"
#include "congestion_control.hpp"
#include <cstdint>
#include <cstring>
@ -10,11 +10,18 @@ const unsigned int RTTLogger::BASE_UPDATE_DELAY = 1000; // ms
VpnPeer::VpnPeer(UdpVpn* vpn, const sockaddr_in6& ext_addr,
const in6_addr& int_addr)
: _vpn(vpn), _ext_addr(ext_addr), _int_addr(int_addr), _next_send_seqno(0)
: _vpn(vpn), _ext_addr(ext_addr), _int_addr(int_addr), _next_send_seqno(0),
_congestion_controller(*this)
{
cycle_next_control();
}
void VpnPeer::make_loss_report() {
VpnTlvLossReport report = VpnTlvLossReport::create(*_next_control_packet);
report.set_report_seqno(_packet_loss.get_cur_seqno());
report.set_losses(_packet_loss.get_tot_losses());
}
void VpnPeer::cycle_next_control() {
_next_control_packet =
std::make_unique<VpnControlPacket>(_vpn->get_mtu(), false);
@ -25,6 +32,16 @@ void VpnPeer::set_int_addr(const in6_addr& int_addr) {
memcpy(&_int_addr, &int_addr, sizeof(_int_addr));
}
void VpnPeer::log_loss_report(const VpnTlvLossReport& loss_rep) {
_loss_reports.prev_seqno = _loss_reports.last_seqno;
_loss_reports.prev_losses = _loss_reports.last_losses;
_loss_reports.last_seqno = loss_rep.get_report_seqno();
_loss_reports.last_losses = loss_rep.get_losses();
_congestion_controller.update_lossbased();
}
size_t VpnPeer::write(const char* data, size_t len) {
ssize_t nsent;
@ -45,8 +62,10 @@ void VpnPeer::got_inbound_packet(const VpnPacket& packet) {
}
bool VpnPeer::send_control_packet() {
if(_rtt.update_due(!_next_control_packet->is_empty()))
if(_rtt.update_due(!_next_control_packet->is_empty())) {
VpnTlvRTTQ::create(*_next_control_packet);
make_loss_report();
}
if(!_next_control_packet->is_empty()) {
_next_control_packet->prepare_for_sending();

View File

@ -7,6 +7,7 @@
#include <chrono>
#include "util.hpp"
#include "VpnPacket.hpp"
#include "congestion_control.hpp"
class UdpVpn;
@ -77,6 +78,17 @@ class VpnPeer {
: MsgException(msg, code, is_perror) {}
};
/// Logs the loss reports sent by the remote peer
struct LossReports {
uint32_t prev_seqno, last_seqno;
uint32_t prev_losses, last_losses;
double loss_rate() const {
return (double)(last_losses - prev_losses)
/ (double)(last_seqno - prev_seqno);
}
};
VpnPeer(UdpVpn* vpn, const sockaddr_in6& ext_addr,
const in6_addr& int_addr);
@ -87,8 +99,12 @@ class VpnPeer {
const PacketLossLogger& get_loss_logger() const { return _packet_loss; }
const RTTLogger& get_rtt() const { return _rtt; }
const LossReports& get_loss_reports() const { return _loss_reports; }
const CongestionController& get_congestion_controller() const {
return _congestion_controller; }
void log_rtta(const VpnTlvRTTA& rtta) { _rtt.log(rtta); }
void log_loss_report(const VpnTlvLossReport& loss_rep);
size_t write(const char* data, size_t len);
size_t write(const VpnPacket& packet);
@ -106,7 +122,8 @@ class VpnPeer {
void make_rtta_for(const VpnTlvRTTQ& rttq);
private: // meth
void cycle_next_control(); /// Generate a fresh next control packet
void make_loss_report(); ///< Add a loss report to the next control packet
void cycle_next_control(); ///< Generate a fresh next control packet
private:
UdpVpn* _vpn;
@ -115,7 +132,9 @@ class VpnPeer {
uint32_t _next_send_seqno;
PacketLossLogger _packet_loss;
LossReports _loss_reports;
RTTLogger _rtt;
CongestionController _congestion_controller;
std::unique_ptr<VpnControlPacket> _next_control_packet;
};

27
congestion_control.cpp Normal file
View File

@ -0,0 +1,27 @@
#include "congestion_control.hpp"
#include "VpnPeer.hpp"
CongestionController::CongestionController(const VpnPeer& peer):
_peer(peer)
{
_last_seqno = _peer.get_loss_logger().get_cur_seqno();
_loss_based.bandwidth = 3e5; // 300kBps seems a good value to start with
}
void CongestionController::update_lossbased() {
const VpnPeer::LossReports& loss_rep = _peer.get_loss_reports();
uint32_t delta_seqno = loss_rep.last_seqno - loss_rep.prev_seqno;
unsigned int delta_losses = loss_rep.last_losses - loss_rep.prev_losses;
double loss_rate = (double)delta_losses / (double)delta_seqno;
if(loss_rate < 0.02) // FIXME only if the bandwidth is used
_loss_based.bandwidth *= 1.05;
else if(loss_rate >= 0.1)
_loss_based.bandwidth *= (1 - 0.5 * loss_rate);
}
uint64_t CongestionController::get_bandwidth() const {
return _loss_based.bandwidth;
}

28
congestion_control.hpp Normal file
View File

@ -0,0 +1,28 @@
#pragma once
#include <stdint.h>
class VpnPeer;
class CongestionController {
public:
struct LossBased {
uint64_t bandwidth; // bytes per second
};
CongestionController(const VpnPeer& peer);
const VpnPeer& get_peer() const { return _peer; }
void update_lossbased();
uint64_t get_bandwidth() const;
uint64_t get_lossbased_bandwidth() const {
return _loss_based.bandwidth; }
private:
const VpnPeer& _peer;
LossBased _loss_based;
uint32_t _last_seqno; // seqno at the last update time
};

View File

@ -33,6 +33,27 @@ format_address(const unsigned char *address)
return buf[i];
}
const char*
human_readable_unit(size_t value, const char* unit) {
static char buf[4][24];
static int cur_buf = 0;
static const char MULTIPLIERS[] = {' ', 'k', 'M', 'G', 'T'};
static const int NB_MULTIPLIERS = 5;
int multiplier = 0;
double div_value = value;
cur_buf = (cur_buf + 1) % 4;
while(div_value >= 1000 && multiplier < NB_MULTIPLIERS) {
div_value /= 1000;
multiplier++;
}
snprintf(buf[cur_buf], 24, "%.1lf %c%s",
div_value, MULTIPLIERS[multiplier], unit);
return buf[cur_buf];
}
namespace std {
size_t hash<in6_addr>::operator() (const in6_addr& addr) const {
size_t out_hash = 0;

View File

@ -47,6 +47,9 @@ void do_debugf(int level, const char *format, ...);
/** format_address -- taken from babeld */
const char* format_address(const unsigned char* address);
/** turns a value into a human-readable one, eg. "21.2 kB" */
const char* human_readable_unit(size_t value, const char* unit);
/** remove the upper bit from a microsecond timestamp, to conform with the
* packet header timestamp format. */