// Radish alpha
// h
// rad:z3gqcJUoA1n9HaHKufZs5FCSGazv5
// Radicle Heartwood Protocol & Stack
// Radicle
// Git
// heartwood/crates/radicle/src/node/routing.rs
use std::collections::HashSet;

use sqlite as sql;
use thiserror::Error;

use crate::node::Database;
use crate::{
    prelude::Timestamp,
    prelude::{NodeId, RepoId},
    sql::transaction,
};

/// Outcome of inserting a single `(repo, node)` route into the routing table.
///
/// One value is reported per repository id passed to `Store::add_inventory`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InsertResult {
    /// The store was left untouched by the insert.
    NotUpdated,
    /// The route already existed and only its timestamp was changed.
    TimeUpdated,
    /// The route did not exist before and a new entry was created.
    SeedAdded,
}

/// An error returned by routing table [`Store`] operations.
#[derive(Error, Debug)]
pub enum Error {
    /// An error reported by the underlying SQL layer; converted automatically
    /// from `sql::Error` via `?`.
    #[error("internal error: {0}")]
    Internal(#[from] sql::Error),
    /// A count read from the database did not fit into the target integer type.
    #[error("the unit overflowed")]
    UnitOverflow,
}

/// Storage interface for the routing table, which maps repositories to the
/// nodes that seed them.
pub trait Store {
    /// Return every node recorded as seeding the repository `id`.
    fn get(&self, id: &RepoId) -> Result<HashSet<NodeId>, Error>;
    /// Return every repository recorded as seeded by `node_id`.
    fn get_inventory(&self, node_id: &NodeId) -> Result<HashSet<RepoId>, Error>;
    /// Look up the timestamp of the route `(id, node)`, if that route exists.
    fn entry(&self, id: &RepoId, node: &NodeId) -> Result<Option<Timestamp>, Error>;
    /// Report whether the routing table holds no entries at all.
    fn is_empty(&self) -> Result<bool, Error> {
        self.len().map(|total| total == 0)
    }
    /// Record `node` as a seed of each repository in `ids`, stamped with `time`.
    ///
    /// Yields one [`InsertResult`] per repository id that was passed in.
    fn add_inventory<'a>(
        &mut self,
        ids: impl IntoIterator<Item = &'a RepoId>,
        node: NodeId,
        time: Timestamp,
    ) -> Result<Vec<(RepoId, InsertResult)>, Error>;
    /// Drop the route `(id, node)`; the boolean tells whether something was removed.
    fn remove_inventory(&mut self, id: &RepoId, node: &NodeId) -> Result<bool, Error>;
    /// Drop the routes from `node` to each of the repositories in `ids`.
    fn remove_inventories<'a>(
        &mut self,
        ids: impl IntoIterator<Item = &'a RepoId>,
        node: &NodeId,
    ) -> Result<(), Error>;
    /// Produce an iterator over every `(repo, node)` route in the table.
    fn entries(&self) -> Result<Box<dyn Iterator<Item = (RepoId, NodeId)>>, Error>;
    /// Return how many routes the table currently holds.
    fn len(&self) -> Result<usize, Error>;
    /// Remove routes whose timestamp is older than `oldest`.
    ///
    /// `limit` optionally caps the number of routes considered, and routes
    /// belonging to the `ignore` node are kept. Returns the number of routes
    /// that were removed.
    fn prune(
        &mut self,
        oldest: Timestamp,
        limit: Option<usize>,
        ignore: &NodeId,
    ) -> Result<usize, Error>;
    /// Return how many routes exist for the repository `id`.
    fn count(&self, id: &RepoId) -> Result<usize, Error>;
}

/// SQLite-backed routing table, stored in the `routing` table with the
/// columns `repo`, `node` and `timestamp`.
///
/// Every method propagates SQL and row-decoding failures as
/// [`Error::Internal`] instead of panicking or silently ignoring them.
impl Store for Database {
    /// Return the nodes recorded as seeding the repository `id`.
    fn get(&self, id: &RepoId) -> Result<HashSet<NodeId>, Error> {
        let mut stmt = self
            .db
            .prepare("SELECT (node) FROM routing WHERE repo = ?")?;
        stmt.bind((1, id))?;

        let mut nodes = HashSet::new();
        for row in stmt.into_iter() {
            // `try_read` surfaces an undecodable value as an error rather
            // than panicking like `read` does.
            nodes.insert(row?.try_read::<NodeId, _>("node")?);
        }
        Ok(nodes)
    }

    /// Return the repositories recorded as seeded by `node`.
    fn get_inventory(&self, node: &NodeId) -> Result<HashSet<RepoId>, Error> {
        let mut stmt = self.db.prepare("SELECT repo FROM routing WHERE node = ?")?;
        stmt.bind((1, node))?;

        let mut inventory = HashSet::new();
        for row in stmt.into_iter() {
            inventory.insert(row?.try_read::<RepoId, _>("repo")?);
        }
        Ok(inventory)
    }

    /// Return the timestamp of the route `(id, node)`, or `None` when no such
    /// route exists.
    fn entry(&self, id: &RepoId, node: &NodeId) -> Result<Option<Timestamp>, Error> {
        let mut stmt = self
            .db
            .prepare("SELECT (timestamp) FROM routing WHERE repo = ? AND node = ?")?;

        stmt.bind((1, id))?;
        stmt.bind((2, node))?;

        // A failing row must be reported as an error; only the absence of a
        // row means that the route does not exist.
        match stmt.into_iter().next() {
            Some(row) => Ok(Some(row?.try_read::<Timestamp, _>("timestamp")?)),
            None => Ok(None),
        }
    }

    /// Record `node` as a seed of every repository in `ids`, inside a single
    /// transaction.
    ///
    /// An existing route only has its timestamp changed when `time` is newer
    /// than the stored one. The returned vector holds one result per id, in
    /// iteration order:
    ///
    /// * [`InsertResult::SeedAdded`] when the route was newly created,
    /// * [`InsertResult::TimeUpdated`] when an existing route got a newer timestamp,
    /// * [`InsertResult::NotUpdated`] when nothing changed.
    fn add_inventory<'a>(
        &mut self,
        ids: impl IntoIterator<Item = &'a RepoId>,
        node: NodeId,
        time: Timestamp,
    ) -> Result<Vec<(RepoId, InsertResult)>, Error> {
        let mut results = Vec::new();
        let mut select_stmt = self
            .db
            .prepare("SELECT (timestamp) FROM routing WHERE repo = ? AND node = ?")?;
        let mut insert_stmt = self.db.prepare(
            "INSERT INTO routing (repo, node, timestamp)
             VALUES (?, ?, ?)
             ON CONFLICT DO UPDATE
             SET timestamp = ?3
             WHERE timestamp < ?3",
        )?;
        transaction(&self.db, |_| {
            for id in ids.into_iter() {
                select_stmt.bind((1, id))?;
                select_stmt.bind((2, &node))?;

                // Propagate a failed lookup instead of mistaking the error
                // for "the route already exists".
                let existed = select_stmt.iter().next().transpose()?.is_some();
                select_stmt.reset()?;

                insert_stmt.bind((1, id))?;
                insert_stmt.bind((2, &node))?;
                insert_stmt.bind((3, &time))?;
                insert_stmt.next()?;
                insert_stmt.reset()?;

                // The change count tells whether the upsert touched a row;
                // `existed` distinguishes a fresh insert from an update.
                let result = match (self.db.change_count() > 0, existed) {
                    (true, true) => InsertResult::TimeUpdated,
                    (true, false) => InsertResult::SeedAdded,
                    (false, _) => InsertResult::NotUpdated,
                };
                results.push((*id, result));
            }
            Ok::<_, Error>(results)
        })
    }

    /// Return all `(repo, node)` routes, ordered by repository.
    ///
    /// The routes are collected eagerly, so the returned iterator does not
    /// borrow the database.
    fn entries(&self) -> Result<Box<dyn Iterator<Item = (RepoId, NodeId)>>, Error> {
        let stmt = self
            .db
            .prepare("SELECT repo, node FROM routing ORDER BY repo")?;
        let mut entries = Vec::new();

        for row in stmt.into_iter() {
            // Fail on a bad row rather than returning a silently truncated list.
            let row = row?;
            let id = row.try_read::<RepoId, _>("repo")?;
            let node = row.try_read::<NodeId, _>("node")?;

            entries.push((id, node));
        }
        Ok(Box::new(entries.into_iter()))
    }

    /// Delete the route `(id, node)`, returning `true` when a row was removed.
    fn remove_inventory(&mut self, id: &RepoId, node: &NodeId) -> Result<bool, Error> {
        let mut stmt = self
            .db
            .prepare("DELETE FROM routing WHERE repo = ? AND node = ?")?;

        stmt.bind((1, id))?;
        stmt.bind((2, node))?;
        stmt.next()?;

        Ok(self.db.change_count() > 0)
    }

    /// Delete the routes from `nid` to every repository in `rids`, inside a
    /// single transaction.
    fn remove_inventories<'a>(
        &mut self,
        rids: impl IntoIterator<Item = &'a RepoId>,
        nid: &NodeId,
    ) -> Result<(), Error> {
        let mut stmt = self
            .db
            .prepare("DELETE FROM routing WHERE repo = ? AND node = ?")?;

        transaction(&self.db, |_| {
            for rid in rids.into_iter() {
                stmt.bind((1, rid))?;
                stmt.bind((2, nid))?;

                // A failed delete is returned to the caller as an error
                // instead of being dropped.
                stmt.next()?;
                stmt.reset()?;
            }
            Ok::<_, Error>(())
        })
    }

    /// Return the total number of routes in the table.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnitOverflow`] if the count does not fit in a `usize`.
    fn len(&self) -> Result<usize, Error> {
        let stmt = self.db.prepare("SELECT COUNT(1) FROM routing")?;
        let count = stmt
            .into_iter()
            .next()
            .expect("COUNT will always return a single row")?
            .try_read::<i64, _>(0)?;

        usize::try_from(count).map_err(|_| Error::UnitOverflow)
    }

    /// Delete routes whose timestamp is older than `oldest`, oldest first,
    /// and return the number of deleted rows.
    ///
    /// Routes of the `ignore` node are never deleted. `limit` caps the number
    /// of deleted rows; `None`, or a value that does not fit in an `i64`,
    /// means no practical cap.
    fn prune(
        &mut self,
        oldest: Timestamp,
        limit: Option<usize>,
        ignore: &NodeId,
    ) -> Result<usize, Error> {
        let limit: i64 = limit
            .and_then(|limit| i64::try_from(limit).ok())
            .unwrap_or(i64::MAX);
        // The `ignore` filter sits inside the sub-select so that stale rows of
        // the ignored node cannot use up the LIMIT window and thereby prevent
        // other stale rows from ever being pruned.
        let mut stmt = self.db.prepare(
            "DELETE FROM routing
             WHERE rowid IN
             (SELECT rowid FROM routing
              WHERE node <> ?1 AND timestamp < ?2
              ORDER BY timestamp LIMIT ?3)",
        )?;
        stmt.bind((1, ignore))?;
        stmt.bind((2, &oldest))?;
        stmt.bind((3, limit))?;
        stmt.next()?;

        Ok(self.db.change_count())
    }

    /// Return the number of routes recorded for the repository `id`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnitOverflow`] if the count does not fit in a `usize`.
    fn count(&self, id: &RepoId) -> Result<usize, Error> {
        let mut stmt = self
            .db
            .prepare("SELECT COUNT(*) FROM routing WHERE repo = ?")?;

        stmt.bind((1, id))?;

        let count = stmt
            .into_iter()
            .next()
            .expect("COUNT will always return a single row")?
            .try_read::<i64, _>(0)?;

        usize::try_from(count).map_err(|_| Error::UnitOverflow)
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use localtime::LocalTime;

    use super::*;
    use crate::test::arbitrary;

    /// Open a database at `path` with the default configuration and with
    /// foreign key enforcement switched off.
    fn database(path: &str) -> Database {
        let config = crate::node::db::config::Config::default();
        let store = Database::open(path, config).unwrap();

        // We don't want to test foreign key constraints here.
        store.db.execute("PRAGMA foreign_keys = OFF").unwrap();
        store
    }

    /// Every node added as a seed of every repo is returned by `get`.
    #[test]
    fn test_insert_and_get() {
        let ids = arbitrary::set::<RepoId>(5..10);
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = database(":memory:");

        for node in &nodes {
            let expected = ids
                .iter()
                .map(|id| (*id, InsertResult::SeedAdded))
                .collect::<Vec<_>>();
            let inserted = db.add_inventory(&ids, *node, Timestamp::EPOCH).unwrap();

            assert_eq!(inserted, expected);
        }

        for id in &ids {
            let seeds = db.get(id).unwrap();

            assert!(nodes.iter().all(|node| seeds.contains(node)));
        }
    }

    /// Every repo added for a node is returned by `get_inventory`.
    #[test]
    fn test_insert_and_get_resources() {
        let ids = arbitrary::set::<RepoId>(5..10);
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = database(":memory:");

        for node in &nodes {
            db.add_inventory(&ids, *node, Timestamp::EPOCH).unwrap();
        }

        for node in &nodes {
            let projects = db.get_inventory(node).unwrap();

            assert!(ids.iter().all(|id| projects.contains(id)));
        }
    }

    /// `entries` returns every route, with routes of the same repo adjacent.
    #[test]
    fn test_entries() {
        let ids = arbitrary::set::<RepoId>(6..9);
        let nodes = arbitrary::set::<NodeId>(6..9);
        let mut db = database(":memory:");

        for node in &nodes {
            let inserted = db.add_inventory(&ids, *node, Timestamp::EPOCH).unwrap();

            assert!(
                inserted
                    .iter()
                    .all(|(_, outcome)| *outcome == InsertResult::SeedAdded)
            );
        }

        let routes: Vec<_> = db.entries().unwrap().collect();
        assert_eq!(routes.len(), ids.len() * nodes.len());

        // Removing consecutive duplicates leaves one id per repo only if the
        // routes of each repo are adjacent.
        let mut route_ids: Vec<_> = routes.iter().map(|(id, _)| *id).collect();
        route_ids.dedup();

        assert_eq!(route_ids.len(), ids.len(), "Entries are grouped by id");
    }

    /// Removing every inserted route leaves no seeds for any repo.
    #[test]
    fn test_insert_and_remove() {
        let ids = arbitrary::set::<RepoId>(5..10);
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = database(":memory:");

        for node in &nodes {
            db.add_inventory(&ids, *node, Timestamp::EPOCH).unwrap();
        }
        for id in &ids {
            for node in &nodes {
                let removed = db.remove_inventory(id, node).unwrap();

                assert!(removed);
            }
        }
        for id in &ids {
            let seeds = db.get(id).unwrap();

            assert!(seeds.is_empty());
        }
    }

    /// Re-inserting the same route with the same timestamp changes nothing.
    #[test]
    fn test_insert_duplicate() {
        let id = arbitrary::r#gen::<RepoId>(1);
        let node = arbitrary::r#gen::<NodeId>(1);
        let mut db = database(":memory:");

        assert_eq!(
            db.add_inventory([&id], node, Timestamp::EPOCH).unwrap(),
            vec![(id, InsertResult::SeedAdded)]
        );
        // Two further identical inserts must both be no-ops.
        for _ in 0..2 {
            assert_eq!(
                db.add_inventory([&id], node, Timestamp::EPOCH).unwrap(),
                vec![(id, InsertResult::NotUpdated)]
            );
        }
    }

    /// Re-inserting a route with a newer timestamp updates the stored time.
    #[test]
    fn test_insert_existing_updated_time() {
        let id = arbitrary::r#gen::<RepoId>(1);
        let node = arbitrary::r#gen::<NodeId>(1);
        let mut db = database(":memory:");
        let later = || Timestamp::try_from(1u64).unwrap();

        let first = db.add_inventory([&id], node, Timestamp::EPOCH).unwrap();
        assert_eq!(first, vec![(id, InsertResult::SeedAdded)]);

        let second = db.add_inventory([&id], node, later()).unwrap();
        assert_eq!(second, vec![(id, InsertResult::TimeUpdated)]);

        let stored = db.entry(&id, &node).unwrap();
        assert_eq!(stored, Some(later()));
    }

    /// A batch insert reports an individual result for each repo.
    #[test]
    fn test_update_existing_multi() {
        let id1 = arbitrary::r#gen::<RepoId>(1);
        let id2 = arbitrary::r#gen::<RepoId>(1);
        let node = arbitrary::r#gen::<NodeId>(1);
        let mut db = database(":memory:");

        let first = db.add_inventory([&id1], node, Timestamp::EPOCH).unwrap();
        assert_eq!(first, vec![(id1, InsertResult::SeedAdded)]);

        let second = db
            .add_inventory([&id1, &id2], node, Timestamp::EPOCH)
            .unwrap();
        assert_eq!(
            second,
            vec![
                (id1, InsertResult::NotUpdated),
                (id2, InsertResult::SeedAdded)
            ]
        );

        let third = db
            .add_inventory([&id1, &id2], node, Timestamp::try_from(1u64).unwrap())
            .unwrap();
        assert_eq!(
            third,
            vec![
                (id1, InsertResult::TimeUpdated),
                (id2, InsertResult::TimeUpdated)
            ]
        );
    }

    /// Removing a route twice reports `true` the first time only.
    #[test]
    fn test_remove_redundant() {
        let id = arbitrary::r#gen::<RepoId>(1);
        let node = arbitrary::r#gen::<NodeId>(1);
        let mut db = database(":memory:");

        let inserted = db.add_inventory([&id], node, Timestamp::EPOCH).unwrap();
        assert_eq!(inserted, vec![(id, InsertResult::SeedAdded)]);

        let first_removal = db.remove_inventory(&id, &node).unwrap();
        let second_removal = db.remove_inventory(&id, &node).unwrap();

        assert!(first_removal);
        assert!(!second_removal);
    }

    /// `remove_inventories` deletes exactly the listed routes.
    #[test]
    fn test_remove_many() {
        let id1 = arbitrary::r#gen::<RepoId>(1);
        let id2 = arbitrary::r#gen::<RepoId>(1);
        let id3 = arbitrary::r#gen::<RepoId>(1);
        let node = arbitrary::r#gen::<NodeId>(1);
        let mut db = database(":memory:");

        db.add_inventory([&id1, &id2, &id3], node, Timestamp::EPOCH)
            .unwrap();
        let before = db.len().unwrap();
        assert_eq!(before, 3);

        db.remove_inventories([&id1, &id3], &node).unwrap();
        let after = db.len().unwrap();
        assert_eq!(after, 1);
    }

    /// `len` reports the number of inserted routes.
    #[test]
    fn test_len() {
        let mut db = database(":memory:");
        let ids = arbitrary::vec::<RepoId>(10);
        let node: NodeId = arbitrary::r#gen(1);
        let time: Timestamp = LocalTime::now().into();

        db.add_inventory(&ids, node, time).unwrap();

        let total = db.len().unwrap();
        assert_eq!(10, total, "correct number of rows in table");
    }

    /// `prune` removes routes older than the cutoff and keeps the rest.
    #[test]
    fn test_prune() {
        let mut rng = fastrand::Rng::new();
        let now = LocalTime::now();
        let ids = arbitrary::vec::<RepoId>(10);
        let nodes = arbitrary::vec::<NodeId>(10);
        let mut db = database(":memory:");

        // First batch: routes stamped strictly before `now`.
        for node in &nodes {
            let time = Timestamp::try_from(rng.u64(..now.as_millis())).unwrap();

            db.add_inventory(&ids, *node, time).unwrap();
        }

        let ids = arbitrary::vec::<RepoId>(10);
        let nodes = arbitrary::vec::<NodeId>(10);

        // Second batch: routes stamped at `now` or later.
        for node in &nodes {
            let time = Timestamp::try_from(rng.u64(now.as_millis()..i64::MAX as u64)).unwrap();

            db.add_inventory(&ids, *node, time).unwrap();
        }

        let ignored: NodeId = arbitrary::r#gen(1);
        let pruned = db.prune(now.into(), None, &ignored).unwrap();
        assert_eq!(pruned, ids.len() * nodes.len());

        let cutoff = Timestamp::from(now);
        for id in &ids {
            for node in &nodes {
                let t = db.entry(id, node).unwrap().unwrap();

                assert!(*t >= *cutoff);
            }
        }
    }

    /// `count` reports the number of nodes seeding a repo.
    #[test]
    fn test_count() {
        let id = arbitrary::r#gen::<RepoId>(1);
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = database(":memory:");

        for node in &nodes {
            db.add_inventory([&id], *node, Timestamp::EPOCH).unwrap();
        }

        let seeders = db.count(&id).unwrap();
        assert_eq!(seeders, nodes.len());
    }
}