congestvpn/VpnPacket.cpp

242 lines
7 KiB
C++

#include "VpnPacket.hpp"
#include "VpnPeer.hpp"
#include <string.h>
// Size of the VPN header that precedes every tunnelled payload.
const size_t VpnPacket::VPN_HEADER_BYTES = 8;
// Size of a TLV header inside a control packet: 1 type byte + 2 length bytes.
const size_t VpnControlPacket::TLV_HEADER_BYTES = 3;

namespace {

// Bytes consumed by the outer transport headers: IPv6 header (40) + UDP
// header (8). We use a TUN device, hence we don't have a layer 2 header.
constexpr size_t OUTER_HEADERS_BYTES = 40 + 8;

// Byte offsets within the VPN header. The control flag is the most
// significant bit of the 32-bit word that also carries the timestamp, which
// is why both offsets are 4.
constexpr int DATA_SEQNO_POS = 0;
constexpr int DATA_CTRLBIT_POS = 4;
constexpr int DATA_TIMESTAMP_POS = 4;

} // namespace
/* Build an empty packet whose buffer can hold one UDP datagram's worth of
 * VPN data for the given link `mtu` (the outer IPv6+UDP headers are excluded).
 * The payload starts empty, just after the VPN header.
 *
 * The VPN header bytes are zeroed so that the seqno/timestamp/control fields
 * have a defined value from the start: derived constructors read-modify-write
 * the control byte (set_control), and accessors may be called before
 * prepare_for_sending() fills the header in. */
VpnPacket::VpnPacket(size_t mtu, bool inbound)
    : _peer(nullptr), _inbound(inbound), _data_space(mtu-OUTER_HEADERS_BYTES),
      _data_size(VPN_HEADER_BYTES), _reception_timestamp(0)
{
    _data = std::unique_ptr<char[]>(new char[mtu - OUTER_HEADERS_BYTES]);
    // Only the header needs a defined value; the payload is always written
    // before it is read (tracked by _data_size).
    memset(_data.get(), 0, VPN_HEADER_BYTES);
}
// Nothing to release by hand: the packet buffer is owned by a unique_ptr.
VpnPacket::~VpnPacket() = default;
VpnPacket::VpnPacket(VpnPacket&& move_from) :
_peer(move_from._peer),
_inbound(move_from._inbound),
_data(std::move(move_from._data)),
_data_space(move_from._data_space),
_data_size(move_from._data_size),
_reception_timestamp(move_from._reception_timestamp)
{}
/* MTU available to tunnelled traffic: the link MTU minus the outer IPv6+UDP
 * headers and the VPN header. No underflow check is performed. */
size_t VpnPacket::get_tunnelled_mtu(size_t udp_mtu) {
    const size_t per_packet_overhead = OUTER_HEADERS_BYTES + VPN_HEADER_BYTES;
    return udp_mtu - per_packet_overhead;
}
/* Attach this packet to `peer` (may be null). For inbound packets with a
 * non-null peer, the peer is notified through got_inbound_packet(). */
void VpnPacket::set_peer(VpnPeer* peer) {
    _peer = peer;
    if(_peer == nullptr || !_inbound)
        return;
    _peer->got_inbound_packet(*this);
}
/* Sequence number from header bytes 0..3, converted from network order. */
uint32_t VpnPacket::get_seqno() const {
    uint32_t wire_seqno;
    memcpy(&wire_seqno, _data.get() + DATA_SEQNO_POS, sizeof wire_seqno);
    return ntohl(wire_seqno);
}
/* Sending timestamp from header bytes 4..7.
 *
 * The most significant bit of that 32-bit word is the control flag, so only
 * the low 31 bits are returned. The word is converted to host order FIRST
 * and masked afterwards: masking the raw network-order word with 0x7fffffff
 * would, on little-endian hosts, keep the control bit and instead clear the
 * high bit of the timestamp's least significant byte. */
uint32_t VpnPacket::get_sending_timestamp() const {
    uint32_t wire_word;
    memcpy(&wire_word, _data.get() + DATA_TIMESTAMP_POS, sizeof wire_word);
    return ntohl(wire_word) & 0x7fffffffU;
}
/* True when the control flag (MSB of header byte 4) is set. */
bool VpnPacket::is_control() const {
    const unsigned char flag_byte =
        static_cast<unsigned char>(_data[DATA_CTRLBIT_POS]);
    return (flag_byte & 0x80) != 0;
}
/* Set or clear the control flag (MSB of header byte 4), leaving the other
 * seven bits of that byte untouched. */
void VpnPacket::set_control(bool is_control) {
    unsigned char& flag_byte =
        reinterpret_cast<unsigned char&>(_data[DATA_CTRLBIT_POS]);
    flag_byte = is_control ? (flag_byte | 0x80) : (flag_byte & 0x7f);
}
/* Fill in the VPN header right before transmission: a fresh sequence number
 * from the attached peer and the current timestamp, preserving the control
 * flag.
 *
 * Throws PeerNotSet (via next_seqno()) when no peer is attached. The seqno is
 * obtained before anything is written, so on that error the header is left
 * unchanged. */
void VpnPacket::prepare_for_sending() {
    const uint32_t seqno = next_seqno();
    // The timestamp shares its 32-bit word with the control flag (MSB). Keep
    // only the low 31 bits so a large timestamp can never flip the flag; the
    // reader (get_sending_timestamp) masks to 31 bits as well.
    const uint32_t ts =
        static_cast<uint32_t>(to_us_timestamp(current_us_timestamp()))
        & 0x7fffffffU;

    char* const header = _data.get();

    uint32_t ts_word;
    memcpy(&ts_word, header + DATA_TIMESTAMP_POS, sizeof ts_word);
    ts_word = (ts_word & htonl(0x80000000UL)) | htonl(ts);

    const uint32_t seqno_word = htonl(seqno);
    memcpy(header + DATA_SEQNO_POS, &seqno_word, sizeof seqno_word);
    memcpy(header + DATA_TIMESTAMP_POS, &ts_word, sizeof ts_word);
}
/* Record the local reception time of this packet. */
void VpnPacket::upon_reception() {
    const auto now = current_us_timestamp();
    _reception_timestamp = to_us_timestamp(now);
}
/* Next outgoing sequence number, taken from the attached peer.
 * Throws PeerNotSet when no peer has been set. */
uint32_t VpnPacket::next_seqno() {
    if(_peer)
        return _peer->next_seqno();
    throw PeerNotSet();
}
VpnControlPacket::VpnControlPacket(size_t mtu, bool inbound)
: VpnPacket(mtu, inbound)
{
set_control(true);
}
VpnControlPacket::VpnControlPacket(VpnPacket&& move_from)
: VpnPacket(std::move(move_from))
{}
/* TLV cursor positioned at the very start of the payload. */
VpnPacketTLV VpnControlPacket::first_tlv() {
    const size_t payload_start = 0;
    return VpnPacketTLV(*this, payload_start);
}
VpnDataPacket::VpnDataPacket(size_t mtu, bool inbound)
: VpnPacket(mtu, inbound), _ipv6_parsed(false)
{
set_control(false);
}
VpnDataPacket::VpnDataPacket(VpnPacket&& move_from)
: VpnPacket(std::move(move_from)), _ipv6_parsed(false)
{}
/* Parse the payload as an IPv6 packet into _ipv6_header.
 * A previous successful parse is reused unless `reparse` is true.
 * Returns whether the (possibly cached) parse succeeded. */
bool VpnDataPacket::parse_as_ipv6(bool reparse) {
    const bool reuse_cached = _ipv6_parsed && !reparse;
    if(!reuse_cached) {
        _ipv6_parsed = parse_ipv6_header(
            get_payload(), get_payload_size(), _ipv6_header);
    }
    return _ipv6_parsed;
}
/* TLV cursor over `packet`, whose TLV header starts `payload_offset` bytes
 * into the packet payload. */
VpnPacketTLV::VpnPacketTLV(VpnControlPacket& packet, size_t payload_offset)
    : _packet(packet)
    , _tlv_pos(payload_offset)
{
}

/* Copy: another cursor on the same packet at the same position. */
VpnPacketTLV::VpnPacketTLV(const VpnPacketTLV& other)
    : VpnPacketTLV(other._packet, other._tlv_pos)
{
}
/* Append a new, empty TLV of `type` at the end of `packet`'s payload.
 * The packet payload grows by the TLV header size; the TLV length is 0.
 * No capacity check is performed here.
 *
 * The 16-bit length field sits at an odd offset (after the 1-byte type), so
 * it is written with memcpy instead of through a uint16_t* — the cast version
 * is a misaligned, type-punned store (undefined behaviour). */
VpnPacketTLV VpnPacketTLV::create(
        VpnControlPacket& packet, VpnPacketTLV::PayloadType type)
{
    VpnPacketTLV tlv = VpnPacketTLV(packet, packet.get_payload_size());
    packet.increase_payload_size(VpnControlPacket::TLV_HEADER_BYTES);
    char* data = tlv.get_data();
    data[0] = static_cast<char>(type);
    const uint16_t zero_len = 0;
    memcpy(data + 1, &zero_len, sizeof zero_len); // Set len to 0
    return tlv;
}
/* New cursor for the TLV that follows this one.
 *
 * A TLV occupies its header (TLV_HEADER_BYTES) plus its payload, so the next
 * one starts after both — the same arithmetic seek_next_tlv() uses. The
 * previous version skipped only the payload and returned a cursor pointing
 * into the middle of the current TLV. */
VpnPacketTLV VpnPacketTLV::next_tlv() {
    const size_t next_offset =
        _tlv_pos + VpnControlPacket::TLV_HEADER_BYTES + get_payload_size();
    return VpnPacketTLV(_packet, next_offset);
}
/* Advance this cursor in place past the current TLV (header + payload). */
void VpnPacketTLV::seek_next_tlv() {
    _tlv_pos += VpnControlPacket::TLV_HEADER_BYTES + get_payload_size();
}
/* Length of this TLV's payload, read from header bytes 1..2.
 *
 * Read with memcpy: the field follows the 1-byte type and is generally not
 * 2-byte aligned, so dereferencing a uint16_t* there is undefined behaviour.
 * NOTE(review): the length is kept in host byte order (no ntohs), unlike the
 * htonl-encoded fields elsewhere in this file; peers of different endianness
 * would disagree — confirm this is intended before relying on it. */
uint16_t VpnPacketTLV::get_payload_size() const {
    uint16_t size;
    memcpy(&size, get_data() + 1, sizeof size);
    return size;
}
/* Set this TLV's payload length and grow/shrink the packet payload size by
 * the difference from the previous length. No capacity check is made here.
 *
 * The length field (header bytes 1..2) is accessed with memcpy because it is
 * generally misaligned for uint16_t. It stays in host byte order, matching
 * get_payload_size(). */
void VpnPacketTLV::set_payload_size(uint16_t size) {
    char* size_field = get_data() + 1;
    uint16_t old_size;
    memcpy(&old_size, size_field, sizeof old_size);
    memcpy(size_field, &size, sizeof size);
    // Both operands promote to int, so a shrink yields a negative delta.
    _packet.increase_payload_size(size - old_size);
}
/* Room for this TLV's payload: the packet's unused payload space plus what
 * this TLV already occupies.
 * NOTE(review): this is only meaningful when this TLV is the last one in the
 * packet — confirm callers only use it that way. */
uint16_t VpnPacketTLV::get_payload_space() const {
    const auto packet_space = _packet.get_payload_space();
    const auto packet_used = _packet.get_payload_size();
    return packet_space - packet_used + get_payload_size();
}
/* TLV type, stored as the first header byte. */
VpnPacketTLV::PayloadType VpnPacketTLV::get_type() const {
    const uint8_t raw_type = *reinterpret_cast<const uint8_t*>(get_data());
    return static_cast<PayloadType>(raw_type);
}

/* Overwrite the TLV type byte. */
void VpnPacketTLV::set_type(VpnPacketTLV::PayloadType type) {
    uint8_t* type_byte = reinterpret_cast<uint8_t*>(get_data());
    *type_byte = static_cast<uint8_t>(type);
}
/* ========== VpnTlvLossReport ========== */
// Field offsets inside a loss-report TLV payload (two 32-bit fields).
const uint32_t VpnTlvLossReport::REP_SEQNO_POS = 0;
const uint32_t VpnTlvLossReport::REP_LOSS_POS = 4;
VpnTlvLossReport::VpnTlvLossReport(
VpnControlPacket& packet, size_t payload_offset)
: VpnPacketTLV(packet, payload_offset) {}
VpnTlvLossReport::VpnTlvLossReport(const VpnPacketTLV& other)
: VpnPacketTLV(other) {}
/* Append a loss-report TLV to `packet` with an 8-byte, zero-filled payload
 * (report seqno + loss count). */
VpnTlvLossReport VpnTlvLossReport::create(VpnControlPacket& packet) {
    const uint16_t report_len = 8;
    VpnTlvLossReport report(
        VpnPacketTLV::create(packet, VpnPacketTLV::PAYLOAD_TYPE_LOSS_REPORT));
    report.set_payload_size(report_len);
    memset(report.get_payload(), 0, report_len);
    return report;
}
/* Accessors for the loss-report payload fields, stored in network byte order.
 *
 * The payload follows a 3-byte TLV header, so these 32-bit fields are
 * generally not 4-byte aligned. They are accessed through memcpy rather than a
 * uint32_t* cast, which would be a misaligned, type-punned access (undefined
 * behaviour, and a fault on strict-alignment CPUs). */

/* Sequence number the report refers to. */
uint32_t VpnTlvLossReport::get_report_seqno() const {
    uint32_t wire;
    memcpy(&wire, get_payload() + REP_SEQNO_POS, sizeof wire);
    return ntohl(wire);
}
void VpnTlvLossReport::set_report_seqno(uint32_t seqno) {
    const uint32_t wire = htonl(seqno);
    memcpy(get_payload() + REP_SEQNO_POS, &wire, sizeof wire);
}

/* Number of losses reported. */
uint32_t VpnTlvLossReport::get_losses() const {
    uint32_t wire;
    memcpy(&wire, get_payload() + REP_LOSS_POS, sizeof wire);
    return ntohl(wire);
}
void VpnTlvLossReport::set_losses(uint32_t losses) {
    const uint32_t wire = htonl(losses);
    memcpy(get_payload() + REP_LOSS_POS, &wire, sizeof wire);
}
/* ========== VpnTlvRTTQ ========== */
VpnTlvRTTQ::VpnTlvRTTQ(VpnControlPacket& packet, size_t payload_offset)
: VpnPacketTLV(packet, payload_offset) {}
VpnTlvRTTQ::VpnTlvRTTQ(const VpnPacketTLV& other)
: VpnPacketTLV(other) {}
/* Append an empty RTT-query TLV to `packet`. */
VpnTlvRTTQ VpnTlvRTTQ::create(VpnControlPacket& packet) {
    return VpnTlvRTTQ(
        VpnPacketTLV::create(packet, VpnPacketTLV::PAYLOAD_TYPE_RTTQ));
}
/* ========== VpnTlvRTTA ========== */
// Field offsets inside an RTT-answer TLV payload (two 32-bit timestamps).
const uint32_t VpnTlvRTTA::EXP_TS_POS = 0;
const uint32_t VpnTlvRTTA::RECV_TS_POS = 4;
VpnTlvRTTA::VpnTlvRTTA(VpnControlPacket& packet, size_t payload_offset)
: VpnPacketTLV(packet, payload_offset) {}
VpnTlvRTTA::VpnTlvRTTA(const VpnPacketTLV& other)
: VpnPacketTLV(other) {}
/* Append an RTT-answer TLV to `packet` with an 8-byte, zero-filled payload
 * (two 32-bit timestamps). */
VpnTlvRTTA VpnTlvRTTA::create(VpnControlPacket& packet) {
    const uint16_t answer_len = 8;
    VpnTlvRTTA answer(
        VpnPacketTLV::create(packet, VpnPacketTLV::PAYLOAD_TYPE_RTTA));
    answer.set_payload_size(answer_len);
    memset(answer.get_payload(), 0, answer_len);
    return answer;
}
/* Accessors for the RTT-answer payload timestamps, stored in network byte
 * order.
 *
 * The payload follows a 3-byte TLV header, so these 32-bit fields are
 * generally not 4-byte aligned. They are accessed through memcpy rather than a
 * uint32_t* cast, which would be a misaligned, type-punned access (undefined
 * behaviour, and a fault on strict-alignment CPUs). */

/* Timestamp at offset EXP_TS_POS. */
uint32_t VpnTlvRTTA::get_exp_ts() const {
    uint32_t wire;
    memcpy(&wire, get_payload() + EXP_TS_POS, sizeof wire);
    return ntohl(wire);
}
void VpnTlvRTTA::set_exp_ts(uint32_t ts) {
    const uint32_t wire = htonl(ts);
    memcpy(get_payload() + EXP_TS_POS, &wire, sizeof wire);
}

/* Timestamp at offset RECV_TS_POS. */
uint32_t VpnTlvRTTA::get_recv_ts() const {
    uint32_t wire;
    memcpy(&wire, get_payload() + RECV_TS_POS, sizeof wire);
    return ntohl(wire);
}
void VpnTlvRTTA::set_recv_ts(uint32_t ts) {
    const uint32_t wire = htonl(ts);
    memcpy(get_payload() + RECV_TS_POS, &wire, sizeof wire);
}