Radish alpha
h
Radicle Heartwood Protocol & Stack
Radicle
Git (anonymous pull)
Log in to clone via SSH
node: Implement token-bucket rate limitter
cloudhead committed 2 years ago
commit fb4a4e0079261e5671a75ccf4397cfa6ef9ac3d3
parent 260b1a428cb719b02bbba1bf50f16b5666c59760
4 files changed +242 -18
modified radicle-node/src/service.rs
@@ -3,15 +3,16 @@
#![allow(clippy::collapsible_if)]
pub mod filter;
pub mod io;
+
pub mod limitter;
pub mod message;
pub mod session;
pub mod tracking;

use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
+
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
-
use std::{fmt, net};

use crossbeam_channel as chan;
use fastrand::Rng;
@@ -29,7 +30,7 @@ use crate::identity::IdentityError;
use crate::identity::{Doc, Id};
use crate::node::routing;
use crate::node::routing::InsertResult;
-
use crate::node::{Address, Alias, Features, FetchResult, Seed, Seeds};
+
use crate::node::{Address, Alias, Features, FetchResult, HostName, Seed, Seeds};
use crate::prelude::*;
use crate::runtime::Emitter;
use crate::service::message::{Announcement, AnnouncementMessage, Ping};
@@ -48,6 +49,7 @@ pub use crate::service::session::Session;

use self::gossip::Gossip;
use self::io::Outbox;
+
use self::limitter::RateLimiter;
use self::message::InventoryAnnouncement;
use self::tracking::NamespacesError;

@@ -204,6 +206,8 @@ pub struct Service<R, A, S, G> {
    rng: Rng,
    /// Fetch requests initiated by user, which are waiting for results.
    fetch_reqs: HashMap<(Id, NodeId), chan::Sender<FetchResult>>,
+
    /// Request/connection rate limitter.
+
    limiter: RateLimiter,
    /// Current tracked repository bloom filter.
    filter: Filter,
    /// Last time the service was idle.
@@ -268,6 +272,7 @@ where
            routing,
            gossip: Gossip::default(),
            outbox: Outbox::default(),
+
            limiter: RateLimiter::default(),
            sessions,
            fetch_reqs: HashMap::new(),
            filter: Filter::empty(),
@@ -710,8 +715,19 @@ where
        }
    }

-
    pub fn accepted(&mut self, _addr: net::SocketAddr) {
-
        // Inbound connection attempt.
+
    /// Inbound connection attempt.
+
    pub fn accepted(&mut self, addr: Address) -> bool {
+
        // Always accept trusted connections.
+
        if addr.is_trusted() {
+
            return true;
+
        }
+
        let host: HostName = addr.into();
+

+
        if self.limiter.limit(host.clone(), &Link::Inbound, self.clock) {
+
            trace!(target: "service", "Rate limitting inbound connection from {host}..");
+
            return false;
+
        }
+
        true
    }

    pub fn attempted(&mut self, nid: NodeId, addr: Address) {
@@ -1097,6 +1113,13 @@ where
            warn!(target: "service", "Session not found for {remote}");
            return Ok(());
        };
+
        if self
+
            .limiter
+
            .limit(peer.addr.clone().into(), &peer.link, self.clock)
+
        {
+
            trace!(target: "service", "Rate limiting message from {remote} ({})", peer.addr);
+
            return Ok(());
+
        }
        peer.last_active = self.clock;
        message.log(log::Level::Debug, remote, Link::Inbound);

added radicle-node/src/service/limitter.rs
@@ -0,0 +1,182 @@
+
use std::collections::HashMap;
+

+
use localtime::LocalTime;
+
use radicle::node::HostName;
+

+
/// Peer rate limitter.
+
///
+
/// Uses a token bucket algorithm, where each address starts with a certain amount of tokens,
+
/// and every request from that address consumes one token. Tokens refill at a predefined
+
/// rate. This mechanism allows for consistent request rates with potential bursts up to the
+
/// bucket's capacity.
+
#[derive(Debug, Default)]
+
pub struct RateLimiter {
+
    buckets: HashMap<HostName, TokenBucket>,
+
}
+

+
impl RateLimiter {
+
    /// Call this when the address has performed some rate-limited action.
+
    /// Returns whether the action is rate-limited or not.
+
    ///
+
    /// Supplying a different amount of tokens per address is useful if for eg. a peer
+
    /// is outbound vs. inbound.
+
    pub fn limit<T: AsTokens>(&mut self, addr: HostName, tokens: &T, now: LocalTime) -> bool {
+
        !self
+
            .buckets
+
            .entry(addr)
+
            .or_insert_with(|| TokenBucket::new(tokens.capacity(), tokens.rate(), now))
+
            .take(now)
+
    }
+
}
+

+
/// Any type that can be assigned a number of rate-limit tokens.
+
pub trait AsTokens {
+
    /// Get the token capacity for this object.
+
    fn capacity(&self) -> usize;
+
    /// Get the refill rate for this object.
+
    /// A rate of `1.0` means one token per second.
+
    fn rate(&self) -> f64;
+
}
+

+
impl AsTokens for crate::Link {
+
    fn rate(&self) -> f64 {
+
        match self {
+
            Self::Inbound => 0.1,
+
            Self::Outbound => 1.0,
+
        }
+
    }
+

+
    fn capacity(&self) -> usize {
+
        match self {
+
            Self::Inbound => 16,
+
            Self::Outbound => 64,
+
        }
+
    }
+
}
+

+
#[derive(Debug)]
+
pub struct TokenBucket {
+
    /// Token refill rate per second.
+
    rate: f64,
+
    /// Token capacity.
+
    capacity: f64,
+
    /// Tokens remaining.
+
    tokens: f64,
+
    /// Time of last token refill.
+
    refilled_at: LocalTime,
+
}
+

+
impl TokenBucket {
+
    fn new(tokens: usize, rate: f64, now: LocalTime) -> Self {
+
        Self {
+
            rate,
+
            capacity: tokens as f64,
+
            tokens: tokens as f64,
+
            refilled_at: now,
+
        }
+
    }
+

+
    fn refill(&mut self, now: LocalTime) {
+
        let elapsed = now.duration_since(self.refilled_at);
+
        let tokens = elapsed.as_secs() as f64 * self.rate;
+

+
        self.tokens = (self.tokens + tokens).min(self.capacity);
+
        self.refilled_at = now;
+
    }
+

+
    fn take(&mut self, now: LocalTime) -> bool {
+
        self.refill(now);
+

+
        if self.tokens >= 1.0 {
+
            self.tokens -= 1.0;
+
            true
+
        } else {
+
            false
+
        }
+
    }
+
}
+

+
#[cfg(test)]
+
#[allow(clippy::bool_assert_comparison, clippy::redundant_clone)]
+
mod test {
+
    use radicle::node::Address;
+
    use radicle::test::arbitrary::gen;
+

+
    use super::*;
+

+
    impl AsTokens for (usize, f64) {
+
        fn capacity(&self) -> usize {
+
            self.0
+
        }
+

+
        fn rate(&self) -> f64 {
+
            self.1
+
        }
+
    }
+

+
    #[test]
+
    fn test_limitter_refill() {
+
        let mut r = RateLimiter::default();
+
        let t = (3, 0.2); // Three tokens burst. One token every 5 seconds.
+
        let a: HostName = gen::<Address>(1).into();
+

+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(0)), false); // Burst capacity
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(1)), false); // Burst capacity
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(2)), false); // Burst capacity
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(3)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(4)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(5)), false); // Refilled (1)
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(6)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(7)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(8)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(9)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(10)), false); // Refilled (1)
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(11)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(12)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(13)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(14)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(15)), false); // Refilled (1)
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(16)), true); // Limited
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(60)), false); // Refilled (3)
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(60)), false); // Burst capacity
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(60)), false); // Burst capacity
+
        assert_eq!(r.limit(a.clone(), &t, LocalTime::from_secs(60)), true); // Limited
+
    }
+

+
    #[test]
+
    fn test_limitter_multi() {
+
        let t = (1, 1.0); // One token per second. One token burst.
+
        let mut r = RateLimiter::default();
+
        let addr1: HostName = gen::<Address>(1).into();
+
        let addr2: HostName = gen::<Address>(1).into();
+

+
        assert_eq!(r.limit(addr1.clone(), &t, LocalTime::from_secs(0)), false);
+
        assert_eq!(r.limit(addr1.clone(), &t, LocalTime::from_secs(0)), true);
+
        assert_eq!(r.limit(addr2.clone(), &t, LocalTime::from_secs(0)), false);
+
        assert_eq!(r.limit(addr2.clone(), &t, LocalTime::from_secs(0)), true);
+
        assert_eq!(r.limit(addr1.clone(), &t, LocalTime::from_secs(1)), false); // Refilled (1)
+
        assert_eq!(r.limit(addr1.clone(), &t, LocalTime::from_secs(1)), true);
+
        assert_eq!(r.limit(addr2.clone(), &t, LocalTime::from_secs(1)), false);
+
        assert_eq!(r.limit(addr2.clone(), &t, LocalTime::from_secs(1)), true);
+
    }
+

+
    #[test]
+
    fn test_limitter_different_rates() {
+
        let t1 = (1, 1.0); // One token per second. One token burst.
+
        let t2 = (2, 2.0); // Two tokens per second. Two token burst.
+
        let mut r = RateLimiter::default();
+
        let addr1: HostName = gen::<Address>(1).into();
+
        let addr2: HostName = gen::<Address>(1).into();
+

+
        assert_eq!(r.limit(addr1.clone(), &t1, LocalTime::from_secs(0)), false);
+
        assert_eq!(r.limit(addr1.clone(), &t1, LocalTime::from_secs(0)), true);
+
        assert_eq!(r.limit(addr2.clone(), &t2, LocalTime::from_secs(0)), false);
+
        assert_eq!(r.limit(addr2.clone(), &t2, LocalTime::from_secs(0)), false);
+
        assert_eq!(r.limit(addr2.clone(), &t2, LocalTime::from_secs(0)), true);
+
        assert_eq!(r.limit(addr1.clone(), &t1, LocalTime::from_secs(1)), false); // Refilled (1)
+
        assert_eq!(r.limit(addr1.clone(), &t1, LocalTime::from_secs(1)), true);
+
        assert_eq!(r.limit(addr2.clone(), &t2, LocalTime::from_secs(1)), false); // Refilled (2)
+
        assert_eq!(r.limit(addr2.clone(), &t2, LocalTime::from_secs(1)), false);
+
        assert_eq!(r.limit(addr2.clone(), &t2, LocalTime::from_secs(1)), true);
+
    }
+
}
modified radicle-node/src/wire/protocol.rs
@@ -466,21 +466,26 @@ where

    fn handle_listener_event(
        &mut self,
-
        socket_addr: net::SocketAddr,
+
        _sock: net::SocketAddr,
        event: ListenerEvent<WireSession<G>>,
        _: Timestamp,
    ) {
        match event {
            ListenerEvent::Accepted(connection) => {
-
                log::debug!(
-
                    target: "wire",
-
                    "Accepting inbound peer connection from {}..",
-
                    connection.remote_addr()
-
                );
-
                self.peers.insert(
-
                    connection.as_raw_fd(),
-
                    Peer::inbound(connection.remote_addr().into()),
-
                );
+
                let addr = connection.remote_addr();
+
                log::debug!(target: "wire", "Accepting inbound peer connection from {addr}..");
+

+
                self.peers
+
                    .insert(connection.as_raw_fd(), Peer::inbound(addr.clone().into()));
+

+
                // If the service doesn't want to accept this connection,
+
                // we drop the connection here, which disconnects the socket.
+
                if !self.service.accepted(NetAddr::from(addr.clone()).into()) {
+
                    log::debug!(target: "wire", "Dropping inbound connection from {addr}..");
+
                    drop(connection);
+

+
                    return;
+
                }

                let session = accept::<G>(connection, self.signer.clone());
                let transport = match NetTransport::with_session(session, Link::Inbound) {
@@ -490,7 +495,6 @@ where
                        return;
                    }
                };
-
                self.service.accepted(socket_addr);
                self.actions
                    .push_back(reactor::Action::RegisterTransport(transport))
            }
modified radicle/src/node.rs
@@ -15,7 +15,7 @@ use std::str::FromStr;
use std::{fmt, io, net, thread, time};

use amplify::WrapperMut;
-
use cyphernet::addr::{HostName, NetAddr};
+
use cyphernet::addr::NetAddr;
use localtime::LocalTime;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
@@ -27,7 +27,7 @@ use crate::storage::RefUpdate;

pub use address::KnownAddress;
pub use config::Config;
-
pub use cyphernet::addr::PeerAddr;
+
pub use cyphernet::addr::{HostName, PeerAddr};
pub use events::{Event, Events};
pub use features::Features;

@@ -261,7 +261,7 @@ impl From<CommandResult> for Result<bool, Error> {
pub struct Address(#[serde(with = "crate::serde_ext::string")] NetAddr<HostName>);

impl Address {
-
    /// Check whether this address is local.
+
    /// Check whether this address is from the local network.
    pub fn is_local(&self) -> bool {
        match self.0.host {
            HostName::Ip(ip) => address::is_local(&ip),
@@ -269,6 +269,15 @@ impl Address {
        }
    }

+
    /// Check whether this address is trusted.
+
    /// Returns true if the address is 127.0.0.1 or 0.0.0.0.
+
    pub fn is_trusted(&self) -> bool {
+
        match self.0.host {
+
            HostName::Ip(ip) => ip.is_loopback() || ip.is_unspecified(),
+
            _ => false,
+
        }
+
    }
+

    /// Check whether this address is globally routable.
    pub fn is_routable(&self) -> bool {
        match self.0.host {
@@ -299,6 +308,12 @@ impl From<net::SocketAddr> for Address {
    }
}

+
impl From<Address> for HostName {
+
    fn from(addr: Address) -> Self {
+
        addr.0.host
+
    }
+
}
+

/// Command name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", tag = "type")]