Radish alpha
h
rad:z3gqcJUoA1n9HaHKufZs5FCSGazv5
Radicle Heartwood Protocol & Stack
Radicle
Git
heartwood crates radicle-protocol src service gossip store.rs
use std::num::TryFromIntError;
use std::{fmt, io};

use radicle::crypto::Signature;
use sqlite as sql;
use thiserror::Error;

use crate::service::filter::Filter;
use crate::service::message::{
    Announcement, AnnouncementMessage, InventoryAnnouncement, NodeAnnouncement, RefsAnnouncement,
};
use crate::wire;
use crate::wire::{Decode as _, Encode as _};
use radicle::node::Database;
use radicle::node::NodeId;
use radicle::prelude::Timestamp;

#[derive(Error, Debug)]
pub enum Error {
    /// An Internal error.
    #[error("internal error: {0}")]
    Internal(#[from] sql::Error),
    /// Unit overflow.
    #[error("unit overflow:: {0}")]
    UnitOverflow(#[from] TryFromIntError),
}

/// Identifier of a stored gossip announcement.
///
/// Values are taken from the SQLite `rowid` of the `announcements` table.
pub type AnnouncementId = u64;

/// A database that has access to historical gossip messages.
/// Keeps track of the latest received gossip messages for each node.
/// Grows linearly with the number of nodes on the network.
pub trait Store {
    /// Prune announcements older than the cutoff time.
    ///
    /// Returns the number of announcements removed.
    fn prune(&mut self, cutoff: Timestamp) -> Result<usize, Error>;

    /// Get the most recent announcement timestamp in the store, or `None` if it is empty.
    fn last(&self) -> Result<Option<Timestamp>, Error>;

    /// Process an announcement for the given node.
    ///
    /// Returns `Some(id)` if the announcement was stored, either because no conflicting
    /// announcement existed before or because it is newer than the stored one and replaced it.
    /// Returns `None` if the stored announcement was already at least as recent.
    fn announced(
        &mut self,
        nid: &NodeId,
        ann: &Announcement,
    ) -> Result<Option<AnnouncementId>, Error>;

    /// Set whether a message should be relayed or not.
    fn set_relay(&mut self, id: AnnouncementId, relay: RelayStatus) -> Result<(), Error>;

    /// Return messages that should be relayed, ordered by announcement id.
    ///
    /// The returned messages are marked as relayed at `now`, so each one is returned only once.
    fn relays(&mut self, now: Timestamp) -> Result<Vec<(AnnouncementId, Announcement)>, Error>;

    /// Get all the latest gossip messages of all nodes, filtered by inventory filter and
    /// announcement timestamps.
    ///
    /// Only announcements whose timestamp lies in the half-open range `from..to` are yielded.
    ///
    /// # Panics
    ///
    /// Panics if `from` > `to`.
    ///
    fn filtered<'a>(
        &'a self,
        filter: &'a Filter,
        from: Timestamp,
        to: Timestamp,
    ) -> Result<Box<dyn Iterator<Item = Result<Announcement, Error>> + 'a>, Error>;
}

impl Store for Database {
    /// Delete every announcement whose timestamp is strictly older than `cutoff`.
    ///
    /// Returns the number of rows removed, as reported by the connection's change count.
    fn prune(&mut self, cutoff: Timestamp) -> Result<usize, Error> {
        let mut stmt = self
            .db
            .prepare("DELETE FROM `announcements` WHERE timestamp < ?1")?;

        stmt.bind((1, &cutoff))?;
        stmt.next()?;

        Ok(self.db.change_count())
    }

    /// Return the greatest stored announcement timestamp, or `None` if the table is empty.
    ///
    /// # Errors
    ///
    /// Fails if the query or row read fails, or if the stored value does not fit a
    /// [`Timestamp`].
    fn last(&self) -> Result<Option<Timestamp>, Error> {
        let stmt = self
            .db
            .prepare("SELECT MAX(timestamp) AS latest FROM `announcements`")?;

        // An aggregate without `GROUP BY` yields one row, whose value is `NULL` on an empty
        // table; the `None` branch here is purely defensive.
        let Some(row) = stmt.into_iter().next() else {
            return Ok(None);
        };
        // Propagate row errors instead of reporting them as an empty store.
        let row = row?;

        match row.try_read::<Option<i64>, _>(0)? {
            Some(latest) => Ok(Some(Timestamp::try_from(latest)?)),
            None => Ok(None),
        }
    }

    /// Insert `ann` for `nid`, or replace the conflicting stored row if `ann` is newer.
    ///
    /// Returns the row id when a row was inserted or updated, and `None` when the existing
    /// row was kept because its timestamp is not older than the new one.
    ///
    /// # Panics
    ///
    /// Panics if the announcement's timestamp equals `Timestamp::MIN`.
    fn announced(
        &mut self,
        nid: &NodeId,
        ann: &Announcement,
    ) -> Result<Option<AnnouncementId>, Error> {
        assert_ne!(
            ann.timestamp(),
            Timestamp::MIN,
            "Timestamp of {ann:?} must not be zero"
        );
        // The `WHERE` on the upsert keeps the stored row when it is at least as recent; in
        // that case nothing is returned.
        let mut stmt = self.db.prepare(
            "INSERT INTO `announcements` (node, repo, type, message, signature, timestamp)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             ON CONFLICT DO UPDATE
             SET message = ?4, signature = ?5, timestamp = ?6
             WHERE timestamp < ?6
             RETURNING rowid",
        )?;
        stmt.bind((1, nid))?;

        match &ann.message {
            AnnouncementMessage::Node(msg) => {
                // Node announcements are not tied to a repository.
                stmt.bind((2, sql::Value::String(String::new())))?;
                stmt.bind((3, &GossipType::Node))?;
                stmt.bind((4, &msg.encode_to_vec()[..]))?;
            }
            AnnouncementMessage::Refs(msg) => {
                stmt.bind((2, &msg.rid))?;
                stmt.bind((3, &GossipType::Refs))?;
                stmt.bind((4, &msg.encode_to_vec()[..]))?;
            }
            AnnouncementMessage::Inventory(msg) => {
                // Inventory announcements are not tied to a single repository.
                stmt.bind((2, sql::Value::String(String::new())))?;
                stmt.bind((3, &GossipType::Inventory))?;
                stmt.bind((4, &msg.encode_to_vec()[..]))?;
            }
        }
        stmt.bind((5, &ann.signature))?;
        stmt.bind((6, &ann.message.timestamp()))?;

        match stmt.into_iter().next() {
            Some(row) => {
                let row = row?;
                let id = row.try_read::<i64, _>("rowid")?;

                // Checked conversion: a negative rowid is reported, not wrapped.
                Ok(Some(AnnouncementId::try_from(id)?))
            }
            None => Ok(None),
        }
    }

    /// Store `relay` as the relay status of the announcement with row id `id`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnitOverflow`] if `id` does not fit an SQLite integer.
    fn set_relay(&mut self, id: AnnouncementId, relay: RelayStatus) -> Result<(), Error> {
        let mut stmt = self.db.prepare(
            "UPDATE announcements
             SET relay = ?1
             WHERE rowid = ?2",
        )?;
        stmt.bind((1, relay))?;
        // Checked conversion: ids above `i64::MAX` are reported instead of wrapping negative.
        stmt.bind((2, i64::try_from(id)?))?;
        stmt.next()?;

        Ok(())
    }

    /// Atomically mark every pending announcement as relayed at `now` and return them,
    /// ordered by row id.
    fn relays(&mut self, now: Timestamp) -> Result<Vec<(AnnouncementId, Announcement)>, Error> {
        let mut stmt = self.db.prepare(
            "UPDATE announcements
             SET relay = ?1
             WHERE relay IS ?2
             RETURNING rowid, node, type, message, signature, timestamp",
        )?;
        stmt.bind((1, RelayStatus::RelayedAt(now)))?;
        stmt.bind((2, RelayStatus::Relay))?;

        let mut rows = stmt
            .into_iter()
            .map(|row| {
                let row = row?;
                parse::announcement(row)
            })
            .collect::<Result<Vec<_>, _>>()?;

        // Nb. Manually sort by insertion order, because we can't use `ORDER BY` with `RETURNING`
        // as of SQLite 3.45. Row ids are unique, so an unstable sort gives the same order.
        rows.sort_unstable_by_key(|(id, _)| *id);

        Ok(rows)
    }

    /// Stream stored announcements with `from <= timestamp < to` that match `filter`,
    /// ordered by timestamp, node and type. Row errors are yielded rather than filtered out.
    ///
    /// # Panics
    ///
    /// Panics if `from` > `to`.
    fn filtered<'a>(
        &'a self,
        filter: &'a Filter,
        from: Timestamp,
        to: Timestamp,
    ) -> Result<Box<dyn Iterator<Item = Result<Announcement, Error>> + 'a>, Error> {
        // Check the precondition before touching the database.
        assert!(*from <= *to);

        let mut stmt = self.db.prepare(
            "SELECT rowid, node, type, message, signature, timestamp
             FROM announcements
             WHERE timestamp >= ?1 and timestamp < ?2
             ORDER BY timestamp, node, type",
        )?;

        stmt.bind((1, &from))?;
        stmt.bind((2, &to))?;

        Ok(Box::new(
            stmt.into_iter()
                .map(|row| {
                    let row = row?;
                    let (_, ann) = parse::announcement(row)?;

                    Ok(ann)
                })
                .filter(|ann| match ann {
                    Ok(a) => a.matches(filter),
                    Err(_) => true,
                }),
        ))
    }
}

impl TryFrom<&sql::Value> for NodeAnnouncement {
    type Error = sql::Error;

    fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
        match value {
            sql::Value::Binary(bytes) => {
                let mut reader = io::Cursor::new(bytes);
                NodeAnnouncement::decode(&mut reader).map_err(wire::Error::into)
            }
            _ => Err(sql::Error {
                code: None,
                message: Some("sql: invalid type for node announcement".to_owned()),
            }),
        }
    }
}

impl TryFrom<&sql::Value> for RefsAnnouncement {
    type Error = sql::Error;

    fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
        match value {
            sql::Value::Binary(bytes) => {
                let mut reader = io::Cursor::new(bytes);
                RefsAnnouncement::decode(&mut reader).map_err(wire::Error::into)
            }
            _ => Err(sql::Error {
                code: None,
                message: Some("sql: invalid type for refs announcement".to_owned()),
            }),
        }
    }
}

impl TryFrom<&sql::Value> for InventoryAnnouncement {
    type Error = sql::Error;

    fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
        match value {
            sql::Value::Binary(bytes) => {
                let mut reader = io::Cursor::new(bytes);
                InventoryAnnouncement::decode(&mut reader).map_err(wire::Error::into)
            }
            _ => Err(sql::Error {
                code: None,
                message: Some("sql: invalid type for inventory announcement".to_owned()),
            }),
        }
    }
}

impl From<wire::Error> for sql::Error {
    fn from(other: wire::Error) -> Self {
        sql::Error {
            code: None,
            message: Some(other.to_string()),
        }
    }
}

/// Relay status of a stored announcement, kept in the `relay` column.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RelayStatus {
    /// Waiting to be relayed; stored as `NULL`.
    Relay,
    /// Must not be relayed; stored as `-1`.
    DontRelay,
    /// Already relayed at the given time; stored as that timestamp.
    RelayedAt(Timestamp),
}

impl sql::BindableWithIndex for RelayStatus {
    /// Bind the status as a `relay` column value: `NULL` when pending, `-1` when relaying is
    /// disabled, or the relay timestamp once relayed.
    fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
        let value = match self {
            Self::RelayedAt(when) => return when.bind(stmt, i),
            Self::Relay => sql::Value::Null,
            Self::DontRelay => sql::Value::Integer(-1),
        };

        value.bind(stmt, i)
    }
}

/// Kind of gossip message, persisted in the `type` column of `announcements`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum GossipType {
    /// A refs announcement.
    Refs,
    /// A node announcement.
    Node,
    /// An inventory announcement.
    Inventory,
}

impl fmt::Display for GossipType {
    /// Write the lowercase name used to persist this gossip type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Refs => "refs",
            Self::Node => "node",
            Self::Inventory => "inventory",
        };

        f.write_str(name)
    }
}

impl sql::BindableWithIndex for &GossipType {
    /// Bind the gossip type as text, using its `Display` name.
    fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
        let name = self.to_string();

        name.as_str().bind(stmt, i)
    }
}

impl TryFrom<&sql::Value> for GossipType {
    type Error = sql::Error;

    fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
        match value {
            sql::Value::String(s) => match s.as_str() {
                "refs" => Ok(Self::Refs),
                "node" => Ok(Self::Node),
                "inventory" => Ok(Self::Inventory),
                other => Err(sql::Error {
                    code: None,
                    message: Some(format!("unknown gossip type '{other}'")),
                }),
            },
            _ => Err(sql::Error {
                code: None,
                message: Some("sql: invalid type for gossip type".to_owned()),
            }),
        }
    }
}

mod parse {
    use super::*;

    pub fn announcement(row: sql::Row) -> Result<(AnnouncementId, Announcement), Error> {
        let id = row.try_read::<i64, _>("rowid")? as AnnouncementId;
        let node = row.try_read::<NodeId, _>("node")?;
        let gt = row.try_read::<GossipType, _>("type")?;
        let message = match gt {
            GossipType::Refs => {
                let ann = row.try_read::<RefsAnnouncement, _>("message")?;
                AnnouncementMessage::Refs(ann)
            }
            GossipType::Inventory => {
                let ann = row.try_read::<InventoryAnnouncement, _>("message")?;
                AnnouncementMessage::Inventory(ann)
            }
            GossipType::Node => {
                let ann = row.try_read::<NodeAnnouncement, _>("message")?;
                AnnouncementMessage::Node(ann)
            }
        };
        let signature = row.try_read::<Signature, _>("signature")?;
        let timestamp = row.try_read::<Timestamp, _>("timestamp")?;

        debug_assert_eq!(timestamp, message.timestamp());

        Ok((
            id,
            Announcement {
                node,
                message,
                signature,
            },
        ))
    }
}

#[cfg(test)]
// Unwrapping is acceptable in tests: a failure should abort the test.
#[allow(clippy::unwrap_used)]
mod test {
    use super::*;
    use crate::bounded::BoundedVec;
    use localtime::LocalTime;
    use radicle::assert_matches;
    use radicle::identity::RepoId;
    use radicle::node::device::Device;
    use radicle::test::arbitrary;

    /// Exercises the full cycle on an in-memory database: announce, mark for relay, then
    /// drain the pending relays.
    #[test]
    fn test_announced() {
        let mut db = Database::memory().unwrap();
        let nid = arbitrary::r#gen::<NodeId>(1);
        let rid = arbitrary::r#gen::<RepoId>(1);
        // Both announcements below share this timestamp.
        let timestamp = LocalTime::now().into();
        let signer = Device::mock();
        let refs = AnnouncementMessage::Refs(RefsAnnouncement {
            rid,
            refs: BoundedVec::new(),
            timestamp,
        })
        .signed(&signer);
        let inv = AnnouncementMessage::Inventory(InventoryAnnouncement {
            inventory: BoundedVec::new(),
            timestamp,
        })
        .signed(&signer);

        // Only the first announcement of each type is recognized as new.
        // Re-announcing with the same timestamp stores nothing, so `None` is returned.
        let id1 = db.announced(&nid, &refs).unwrap().unwrap();
        assert!(db.announced(&nid, &refs).unwrap().is_none());

        let id2 = db.announced(&nid, &inv).unwrap().unwrap();
        assert!(db.announced(&nid, &inv).unwrap().is_none());

        // Nothing was set to be relayed.
        assert_eq!(db.relays(LocalTime::now().into()).unwrap().len(), 0);

        // Set the messages to be relayed.
        db.set_relay(id1, RelayStatus::Relay).unwrap();
        db.set_relay(id2, RelayStatus::Relay).unwrap();

        // Now they are returned, in insertion (row id) order.
        assert_matches!(
            db.relays(LocalTime::now().into()).unwrap().as_slice(),
            &[(id1_, _), (id2_, _)]
            if id1_ == id1 && id2_ == id2
        );
        // But only once: `relays` marks what it returns as relayed.
        assert_matches!(db.relays(LocalTime::now().into()).unwrap().as_slice(), &[]);
    }
}