Radish alpha
h
Radicle Heartwood Protocol & Stack
Radicle
Git (anonymous pull)
Log in to clone via SSH
Use transaction to insert inventories
Alexis Sellier committed 2 years ago
commit 6b6f087c14a80cc7becbfb6ef5e953c5558a4182
parent b27313bc5c703c8a25343176c620c8652b815abc
3 files changed +139 -87
modified radicle-node/src/service.rs
@@ -333,6 +333,11 @@ where
        &mut self.addresses
    }

+
    /// Get the routing store.
+
    pub fn routing(&self) -> &R {
+
        &self.routing
+
    }
+

    /// Get the storage instance.
    pub fn storage(&self) -> &S {
        &self.storage
@@ -386,9 +391,11 @@ where
        // Ensure that our inventory is recorded in our routing table, and we are tracking
        // all of it. It can happen that inventory is not properly tracked if for eg. the
        // user creates a new repository while the node is stopped.
-
        for rid in self.storage.inventory()? {
-
            self.routing.insert(rid, self.node_id(), time.as_millis())?;
+
        let rids = self.storage.inventory()?;
+
        self.routing
+
            .insert(&rids, self.node_id(), time.as_millis())?;

+
        for rid in rids {
            if !self.is_tracking(&rid)? {
                if self
                    .track_repo(&rid, tracking::Scope::Trusted)
@@ -913,11 +920,11 @@ where

                // We update inventories when receiving ref announcements, as these could come
                // from a new repository being initialized.
-
                if let Ok(result) = self
-
                    .routing
-
                    .insert(message.rid, *announcer, message.timestamp)
+
                if let Ok(result) =
+
                    self.routing
+
                        .insert([&message.rid], *announcer, message.timestamp)
                {
-
                    if let InsertResult::SeedAdded = result {
+
                    if let &[(_, InsertResult::SeedAdded)] = result.as_slice() {
                        self.emitter.emit(Event::SeedDiscovered {
                            rid: message.rid,
                            nid: *relayer,
@@ -1182,28 +1189,24 @@ where
        timestamp: Timestamp,
    ) -> Result<SyncedRouting, Error> {
        let mut synced = SyncedRouting::default();
-
        let mut included = HashSet::new();
+
        let included: HashSet<&Id> = HashSet::from_iter(inventory);

-
        for rid in inventory {
-
            included.insert(rid);
-
            match self.routing.insert(*rid, from, timestamp)? {
+
        for (rid, result) in self.routing.insert(inventory, from, timestamp)? {
+
            match result {
                InsertResult::SeedAdded => {
                    info!(target: "service", "Routing table updated for {rid} with seed {from}");
-
                    self.emitter.emit(Event::SeedDiscovered {
-
                        rid: *rid,
-
                        nid: from,
-
                    });
+
                    self.emitter.emit(Event::SeedDiscovered { rid, nid: from });

-
                    if self.tracking.is_repo_tracked(rid).expect(
+
                    if self.tracking.is_repo_tracked(&rid).expect(
                        "Service::process_inventory: error accessing tracking configuration",
                    ) {
                        // TODO: We should fetch here if we're already connected, case this seed has
                        // refs we don't have.
                    }
-
                    synced.added.push(*rid);
+
                    synced.added.push(rid);
                }
                InsertResult::TimeUpdated => {
-
                    synced.updated.push(*rid);
+
                    synced.updated.push(rid);
                }
                InsertResult::NotUpdated => {}
            }
@@ -1539,8 +1542,6 @@ pub trait ServiceState {
    fn clock_mut(&mut self) -> &mut LocalTime;
    /// Get service configuration.
    fn config(&self) -> &Config;
-
    /// Get reference to routing table.
-
    fn routing(&self) -> &dyn routing::Store;
}

impl<R, A, S, G> ServiceState for Service<R, A, S, G>
@@ -1572,10 +1573,6 @@ where
    fn config(&self) -> &Config {
        &self.config
    }
-

-
    fn routing(&self) -> &dyn routing::Store {
-
        &self.routing
-
    }
}

/// Disconnect reason.
modified radicle-node/src/tests.rs
@@ -8,6 +8,7 @@ use std::time;

use crossbeam_channel as chan;
use netservices::Direction as Link;
+
use radicle::node::routing::Store as _;
use radicle::storage::ReadRepository;

use crate::collections::{HashMap, HashSet};
modified radicle/src/node/routing.rs
@@ -95,7 +95,12 @@ pub trait Store {
        Ok(self.len()? == 0)
    }
    /// Add a new node seeding the given id.
-
    fn insert(&mut self, id: Id, node: NodeId, time: Timestamp) -> Result<InsertResult, Error>;
+
    fn insert<'a>(
+
        &mut self,
+
        ids: impl IntoIterator<Item = &'a Id>,
+
        node: NodeId,
+
        time: Timestamp,
+
    ) -> Result<Vec<(Id, InsertResult)>, Error>;
    /// Remove a node for the given id.
    fn remove(&mut self, id: &Id, node: &NodeId) -> Result<bool, Error>;
    /// Iterate over all entries in the routing table.
@@ -149,35 +154,45 @@ impl Store for Table {
        Ok(None)
    }

-
    fn insert(&mut self, id: Id, node: NodeId, time: Timestamp) -> Result<InsertResult, Error> {
+
    fn insert<'a>(
+
        &mut self,
+
        ids: impl IntoIterator<Item = &'a Id>,
+
        node: NodeId,
+
        time: Timestamp,
+
    ) -> Result<Vec<(Id, InsertResult)>, Error> {
        let time: i64 = time.try_into().map_err(|_| Error::UnitOverflow)?;
+
        let mut results = Vec::new();

        transaction(&self.db, |db| {
-
            let mut stmt =
-
                db.prepare("SELECT (time) FROM routing WHERE resource = ? AND node = ?")?;
-

-
            stmt.bind((1, &id))?;
-
            stmt.bind((2, &node))?;
-

-
            let existed = stmt.into_iter().next().is_some();
-
            let mut stmt = db.prepare(
-
                "INSERT INTO routing (resource, node, time)
-
                 VALUES (?, ?, ?)
-
                 ON CONFLICT DO UPDATE
-
                 SET time = ?3
-
                 WHERE time < ?3",
-
            )?;
-

-
            stmt.bind((1, &id))?;
-
            stmt.bind((2, &node))?;
-
            stmt.bind((3, time))?;
-
            stmt.next()?;
-

-
            Ok(match (self.db.change_count() > 0, existed) {
-
                (true, true) => InsertResult::TimeUpdated,
-
                (true, false) => InsertResult::SeedAdded,
-
                (false, _) => InsertResult::NotUpdated,
-
            })
+
            for id in ids.into_iter() {
+
                let mut stmt =
+
                    db.prepare("SELECT (time) FROM routing WHERE resource = ? AND node = ?")?;
+

+
                stmt.bind((1, id))?;
+
                stmt.bind((2, &node))?;
+

+
                let existed = stmt.into_iter().next().is_some();
+
                let mut stmt = db.prepare(
+
                    "INSERT INTO routing (resource, node, time)
+
                     VALUES (?, ?, ?)
+
                     ON CONFLICT DO UPDATE
+
                     SET time = ?3
+
                     WHERE time < ?3",
+
                )?;
+

+
                stmt.bind((1, id))?;
+
                stmt.bind((2, &node))?;
+
                stmt.bind((3, time))?;
+
                stmt.next()?;
+

+
                let result = match (self.db.change_count() > 0, existed) {
+
                    (true, true) => InsertResult::TimeUpdated,
+
                    (true, false) => InsertResult::SeedAdded,
+
                    (false, _) => InsertResult::NotUpdated,
+
                };
+
                results.push((*id, result));
+
            }
+
            Ok(results)
        })
        .map_err(Error::from)
    }
@@ -271,10 +286,13 @@ mod test {
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = Table::open(":memory:").unwrap();

-
        for id in &ids {
-
            for node in &nodes {
-
                assert_eq!(db.insert(*id, *node, 0).unwrap(), InsertResult::SeedAdded);
-
            }
+
        for node in &nodes {
+
            assert_eq!(
+
                db.insert(&ids, *node, 0).unwrap(),
+
                ids.iter()
+
                    .map(|id| (*id, InsertResult::SeedAdded))
+
                    .collect::<Vec<_>>()
+
            );
        }

        for id in &ids {
@@ -291,10 +309,8 @@ mod test {
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = Table::open(":memory:").unwrap();

-
        for id in &ids {
-
            for node in &nodes {
-
                assert_eq!(db.insert(*id, *node, 0).unwrap(), InsertResult::SeedAdded);
-
            }
+
        for node in &nodes {
+
            db.insert(&ids, *node, 0).unwrap();
        }

        for node in &nodes {
@@ -311,10 +327,12 @@ mod test {
        let nodes = arbitrary::set::<NodeId>(6..9);
        let mut db = Table::open(":memory:").unwrap();

-
        for id in &ids {
-
            for node in &nodes {
-
                assert_eq!(db.insert(*id, *node, 0).unwrap(), InsertResult::SeedAdded);
-
            }
+
        for node in &nodes {
+
            assert!(db
+
                .insert(&ids, *node, 0)
+
                .unwrap()
+
                .iter()
+
                .all(|(_, r)| *r == InsertResult::SeedAdded));
        }

        let results = db.entries().unwrap().collect::<Vec<_>>();
@@ -332,10 +350,8 @@ mod test {
        let nodes = arbitrary::set::<NodeId>(5..10);
        let mut db = Table::open(":memory:").unwrap();

-
        for id in &ids {
-
            for node in &nodes {
-
                db.insert(*id, *node, 0).unwrap();
-
            }
+
        for node in &nodes {
+
            db.insert(&ids, *node, 0).unwrap();
        }
        for id in &ids {
            for node in &nodes {
@@ -353,9 +369,18 @@ mod test {
        let node = arbitrary::gen::<NodeId>(1);
        let mut db = Table::open(":memory:").unwrap();

-
        assert_eq!(db.insert(id, node, 0).unwrap(), InsertResult::SeedAdded);
-
        assert_eq!(db.insert(id, node, 0).unwrap(), InsertResult::NotUpdated);
-
        assert_eq!(db.insert(id, node, 0).unwrap(), InsertResult::NotUpdated);
+
        assert_eq!(
+
            db.insert([&id], node, 0).unwrap(),
+
            vec![(id, InsertResult::SeedAdded)]
+
        );
+
        assert_eq!(
+
            db.insert([&id], node, 0).unwrap(),
+
            vec![(id, InsertResult::NotUpdated)]
+
        );
+
        assert_eq!(
+
            db.insert([&id], node, 0).unwrap(),
+
            vec![(id, InsertResult::NotUpdated)]
+
        );
    }

    #[test]
@@ -364,18 +389,54 @@ mod test {
        let node = arbitrary::gen::<NodeId>(1);
        let mut db = Table::open(":memory:").unwrap();

-
        assert_eq!(db.insert(id, node, 0).unwrap(), InsertResult::SeedAdded);
-
        assert_eq!(db.insert(id, node, 1).unwrap(), InsertResult::TimeUpdated);
+
        assert_eq!(
+
            db.insert([&id], node, 0).unwrap(),
+
            vec![(id, InsertResult::SeedAdded)]
+
        );
+
        assert_eq!(
+
            db.insert([&id], node, 1).unwrap(),
+
            vec![(id, InsertResult::TimeUpdated)]
+
        );
        assert_eq!(db.entry(&id, &node).unwrap(), Some(1));
    }

    #[test]
+
    fn test_update_existing_multi() {
+
        let id1 = arbitrary::gen::<Id>(1);
+
        let id2 = arbitrary::gen::<Id>(1);
+
        let node = arbitrary::gen::<NodeId>(1);
+
        let mut db = Table::open(":memory:").unwrap();
+

+
        assert_eq!(
+
            db.insert([&id1], node, 0).unwrap(),
+
            vec![(id1, InsertResult::SeedAdded)]
+
        );
+
        assert_eq!(
+
            db.insert([&id1, &id2], node, 0).unwrap(),
+
            vec![
+
                (id1, InsertResult::NotUpdated),
+
                (id2, InsertResult::SeedAdded)
+
            ]
+
        );
+
        assert_eq!(
+
            db.insert([&id1, &id2], node, 1).unwrap(),
+
            vec![
+
                (id1, InsertResult::TimeUpdated),
+
                (id2, InsertResult::TimeUpdated)
+
            ]
+
        );
+
    }
+

+
    #[test]
    fn test_remove_redundant() {
        let id = arbitrary::gen::<Id>(1);
        let node = arbitrary::gen::<NodeId>(1);
        let mut db = Table::open(":memory:").unwrap();

-
        assert_eq!(db.insert(id, node, 0).unwrap(), InsertResult::SeedAdded);
+
        assert_eq!(
+
            db.insert([&id], node, 0).unwrap(),
+
            vec![(id, InsertResult::SeedAdded)]
+
        );
        assert!(db.remove(&id, &node).unwrap());
        assert!(!db.remove(&id, &node).unwrap());
    }
@@ -386,9 +447,7 @@ mod test {
        let ids = arbitrary::vec::<Id>(10);
        let node = arbitrary::gen(1);

-
        for id in ids {
-
            db.insert(id, node, LocalTime::now().as_millis()).unwrap();
-
        }
+
        db.insert(&ids, node, LocalTime::now().as_millis()).unwrap();

        assert_eq!(10, db.len().unwrap(), "correct number of rows in table");
    }
@@ -401,21 +460,17 @@ mod test {
        let nodes = arbitrary::vec::<NodeId>(10);
        let mut db = Table::open(":memory:").unwrap();

-
        for id in &ids {
-
            for node in &nodes {
-
                let time = rng.u64(..now.as_millis());
-
                db.insert(*id, *node, time).unwrap();
-
            }
+
        for node in &nodes {
+
            let time = rng.u64(..now.as_millis());
+
            db.insert(&ids, *node, time).unwrap();
        }

        let ids = arbitrary::vec::<Id>(10);
        let nodes = arbitrary::vec::<NodeId>(10);

-
        for id in &ids {
-
            for node in &nodes {
-
                let time = rng.u64(now.as_millis()..i64::MAX as u64);
-
                db.insert(*id, *node, time).unwrap();
-
            }
+
        for node in &nodes {
+
            let time = rng.u64(now.as_millis()..i64::MAX as u64);
+
            db.insert(&ids, *node, time).unwrap();
        }

        let pruned = db.prune(now.as_millis(), None).unwrap();
@@ -436,9 +491,8 @@ mod test {
        let mut db = Table::open(":memory:").unwrap();

        for node in &nodes {
-
            db.insert(id, *node, 0).unwrap();
+
            db.insert([&id], *node, 0).unwrap();
        }
-

        assert_eq!(db.count(&id).unwrap(), nodes.len());
    }
}