Radish alpha
h
rad:z3gqcJUoA1n9HaHKufZs5FCSGazv5
Radicle Heartwood Protocol & Stack
Radicle
Git
heartwood crates radicle src cob cache.rs
mod migrations;

use std::collections::HashMap;
use std::convert::Infallible;
use std::fmt;
use std::marker::PhantomData;
use std::path::Path;
use std::sync::Arc;
use std::time;

use radicle_cob::ObjectId;
use sqlite as sql;
use thiserror::Error;

use crate::prelude::RepoId;
use crate::sql::transaction;

/// File suffix for storing the COBs database file.
pub const COBS_DB_FILE: &str = "cache.db";

/// How long to wait for the database lock to be released before failing a read.
const DB_READ_TIMEOUT: time::Duration = time::Duration::from_secs(3);
/// How long to wait for the database lock to be released before failing a write.
const DB_WRITE_TIMEOUT: time::Duration = time::Duration::from_secs(6);

/// Database migrations.
/// The first migration is the creation of the initial tables.
const MIGRATIONS: &[Migration] = &[
    Migration::Sql(include_str!("cache/migrations/1.sql")),
    Migration::Native(migrations::_2::run),
    Migration::Sql(include_str!("cache/migrations/3.sql")),
];

/// Function signature for native migrations.
type MigrateFn = fn(&sql::Connection, &Progress, &mut dyn MigrateCallback) -> Result<usize, Error>;

/// A database migration.
enum Migration {
    /// Migration written in SQL.
    Sql(&'static str),
    /// Migration function written in Rust.
    Native(MigrateFn),
}

/// Progress of a database migration.
#[derive(Debug)]
pub struct MigrateProgress<'a> {
    /// Progress in the list of migrations.
    pub migration: &'a Progress,
    /// Progress within each individual migration.
    pub rows: &'a Progress,
}

impl MigrateProgress<'_> {
    /// If we're done with the migration.
    pub fn is_done(&self) -> bool {
        self.migration.current() == self.migration.total()
            && self.rows.current() == self.rows.total()
    }
}

/// Something that can process migration progress.
pub trait MigrateCallback {
    /// A migration has progressed.
    fn progress(&mut self, progress: MigrateProgress<'_>);
}

impl<F> MigrateCallback for F
where
    F: Fn(MigrateProgress),
{
    fn progress(&mut self, progress: MigrateProgress) {
        (self)(progress)
    }
}

/// Migration functions that implement [`MigrateCallback`].
pub mod migrate {
    use super::*;

    /// Log progress via installed logger at "info" level.
    pub fn log(progress: MigrateProgress<'_>) {
        log::trace!(
            target: "db",
            "Migration {}/{} in progress.. ({}%)",
            progress.migration.current(),
            progress.migration.total(),
            progress.rows.percentage()
        );
    }

    /// Ignore progress, just migrate.
    pub fn ignore(_progress: MigrateProgress<'_>) {}
}

#[derive(Error, Debug)]
pub enum Error {
    /// An Internal error.
    #[error("internal error: {0}")]
    Internal(#[from] sql::Error),
    /// Malformed JSON schema, eg. missing fields or wrong field types.
    #[error("malformed JSON schema")]
    MalformedJsonSchema,
    /// Malformed JSON data, ie. not valid JSON.
    #[error("malformed JSON data: {0}")]
    MalformedJson(serde_json::Error),
    /// No rows returned in query result.
    #[error("no rows returned")]
    NoRows,
    /// Schema is out of date, migrations need to be run.
    #[error("collaborative objects database is out of date")]
    OutOfDate,
}

/// Read from and write to the store.
pub type StoreWriter = Store<Write>;

/// Read from the store.
pub type StoreReader = Store<Read>;

/// Read-only type witness.
#[derive(Clone)]
pub struct Read;
/// Read-write type witness.
#[derive(Clone)]
pub struct Write;

/// A file-backed database storing materialized COBs.
#[derive(Clone)]
pub struct Store<T> {
    pub(super) db: Arc<sql::ConnectionThreadSafe>,
    marker: PhantomData<T>,
}

impl<T> fmt::Debug for Store<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Database").finish()
    }
}

impl Store<Read> {
    /// Same as [`Self::open`], but in read-only mode. This is useful to have multiple
    /// open databases, as no locking is required.
    pub fn reader<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        let mut db = sql::Connection::open_thread_safe_with_flags(
            path,
            sqlite::OpenFlags::new().with_read_only(),
        )?;
        db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;

        Ok(Self {
            db: Arc::new(db),
            marker: PhantomData,
        })
    }

    /// Create a new in-memory database.
    pub fn memory() -> Result<Self, Error> {
        let mut db = sql::Connection::open_thread_safe_with_flags(
            ":memory:",
            sqlite::OpenFlags::new().with_read_only(),
        )?;
        db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;

        Ok(Self {
            db: Arc::new(db),
            marker: PhantomData,
        })
    }
}

impl Store<Write> {
    /// Open a database at the given path. Creates a new database if it
    /// doesn't exist.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        let mut db = sql::Connection::open_thread_safe(path)?;
        db.set_busy_timeout(DB_WRITE_TIMEOUT.as_millis() as usize)?;

        Ok(Self {
            db: Arc::new(db),
            marker: PhantomData,
        })
    }

    /// Create a new in-memory database.
    pub fn memory() -> Result<Self, Error> {
        let db = Arc::new(sql::Connection::open_thread_safe(":memory:")?);

        Ok(Self {
            db,
            marker: PhantomData,
        })
    }

    /// Builder method that migrates the database.
    pub fn with_migrations<M: MigrateCallback>(mut self, callback: M) -> Result<Self, Error> {
        self.migrate(callback).map(|_| self)
    }

    /// Turn this handle into a read-only handle.
    pub fn read_only(self) -> Store<Read> {
        Store {
            db: self.db,
            marker: PhantomData,
        }
    }

    /// Perform a raw query on the database handle.
    pub fn raw_query<T, E, F>(&self, query: F) -> Result<T, E>
    where
        F: FnOnce(&sql::Connection) -> Result<T, E>,
        E: From<sql::Error>,
    {
        transaction(&self.db, query)
    }

    /// Migrate this database to the latest version.
    /// Returns the version migrated to.
    pub fn migrate<M: MigrateCallback>(&mut self, callback: M) -> Result<usize, Error> {
        self.migrate_to(MIGRATIONS.len(), callback)
    }

    /// Migrate this database to the given target version.
    /// Returns the version migrated to.
    pub fn migrate_to<M: MigrateCallback>(
        &mut self,
        target: usize,
        mut callback: M,
    ) -> Result<usize, Error> {
        let db = &self.db;
        let mut version = version(db)?;
        let total = MIGRATIONS.len();

        for (i, migration) in MIGRATIONS.iter().enumerate().take(target).skip(version) {
            let current = i + 1;

            transaction(db, |db| {
                match migration {
                    Migration::Sql(query) => {
                        db.execute(query)?;
                        callback.progress(MigrateProgress {
                            migration: &Progress { total, current },
                            rows: &Progress::done(1),
                        });
                    }
                    Migration::Native(migrate) => {
                        migrate(db, &Progress { total, current }, &mut callback)?;
                    }
                }
                version = bump(db)?;

                Ok::<_, Error>(())
            })?;
        }
        Ok(version)
    }
}

impl<T> Store<T> {
    /// Get the database version. This is updated on schema changes.
    pub fn version(&self) -> Result<usize, Error> {
        version(&self.db)
    }

    /// Check if the database version is out of date, ie. we need to migrate.
    pub fn check_version(&self) -> Result<(), Error> {
        if version(&self.db)? < MIGRATIONS.len() {
            return Err(Error::OutOfDate);
        }
        Ok(())
    }
}

/// Get the `user_version` value from the database header.
pub fn version(db: &sql::Connection) -> Result<usize, Error> {
    let version = db
        .prepare("PRAGMA user_version")?
        .into_iter()
        .next()
        .ok_or(Error::NoRows)??
        .read::<i64, _>(0);

    Ok(version as usize)
}

/// Bump the `user_version` value.
fn bump(db: &sql::Connection) -> Result<usize, Error> {
    let old = version(db)?;
    let new = old + 1;

    db.execute(format!("PRAGMA user_version = {new}"))?;

    Ok(new as usize)
}

/// Update a COB object in the cache.
pub trait Update<T> {
    /// The output type, if any, for a successful update.
    type Out;
    type UpdateError: std::error::Error + Send + Sync + 'static;

    fn update(
        &mut self,
        rid: &RepoId,
        id: &ObjectId,
        object: &T,
    ) -> Result<Self::Out, Self::UpdateError>;
}

/// Remove a COB object in the cache.
pub trait Remove<T> {
    /// The output type, if any, for a successful removal.
    type Out;
    type RemoveError: std::error::Error + Send + Sync + 'static;

    /// Delete an object in the COB cache.
    ///
    /// This assumes that the `id` is unique across repositories.
    fn remove(&mut self, id: &ObjectId) -> Result<Self::Out, Self::RemoveError>;
    /// Delete all entries from a repo.
    fn remove_all(&mut self, rid: &RepoId) -> Result<Self::Out, Self::RemoveError>;
}

/// An in-memory cache for storing COB objects.
///
/// The intention is for this to be used in tests that expect cache
/// reads.
#[derive(Clone, Debug)]
pub struct InMemory<T> {
    inner: HashMap<RepoId, HashMap<ObjectId, T>>,
}

impl<T> Default for InMemory<T> {
    fn default() -> Self {
        Self {
            inner: HashMap::new(),
        }
    }
}

impl<T> Update<T> for InMemory<T>
where
    T: Clone,
{
    type Out = Option<T>;
    type UpdateError = Infallible;

    fn update(
        &mut self,
        rid: &RepoId,
        id: &ObjectId,
        object: &T,
    ) -> Result<Self::Out, Self::UpdateError> {
        let objects = self.inner.entry(*rid).or_default();
        Ok(objects.insert(*id, object.clone()))
    }
}

/// The `/dev/null` of caches.
///
/// It will ignore any updates, and successfully return on each call
/// of [`Update::update`].
///
/// The intention is for this to be used in tests that do not expect
/// any cache reads.
pub struct NoCache;

impl<T> Update<T> for NoCache {
    type Out = ();
    type UpdateError = Infallible;

    fn update(
        &mut self,
        _rid: &RepoId,
        _id: &ObjectId,
        _object: &T,
    ) -> Result<Self::Out, Self::UpdateError> {
        Ok(())
    }
}

impl<T> Remove<T> for NoCache {
    type Out = ();
    type RemoveError = Infallible;

    fn remove(&mut self, _id: &ObjectId) -> Result<Self::Out, Self::RemoveError> {
        Ok(())
    }

    fn remove_all(&mut self, _rid: &RepoId) -> Result<Self::Out, Self::RemoveError> {
        Ok(())
    }
}

/// Track the progress of cache writes when transferring the
/// repository COBs to their respective caches.
///
/// See [`crate::cob::issue::Cache::write_all`] and
/// [`crate::cob::patch::Cache::write_all`].
#[derive(Debug)]
pub struct Progress {
    current: usize,
    total: usize,
}

impl Progress {
    /// Create a new progress tracker with the given `total` amount.
    pub fn new(total: usize) -> Self {
        Self { current: 0, total }
    }

    /// Create a new progress tracker that is "done".
    pub fn done(total: usize) -> Self {
        Self {
            current: total,
            total,
        }
    }

    /// Increment the [`Progress::current`] progress.
    pub fn inc(&mut self) {
        self.current += 1;
    }

    /// Return the `total` amount.
    pub fn total(&self) -> usize {
        self.total
    }

    /// Return the `current` amount.
    pub fn current(&self) -> usize {
        self.current
    }

    /// Return the percentage of the progress made.
    pub fn percentage(&self) -> f32 {
        if self.total == 0 {
            100.
        } else {
            (self.current as f32 / self.total as f32) * 100.0
        }
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::assert_matches;

    #[test]
    fn test_check_version() {
        let mut db = StoreWriter::memory().unwrap();
        assert_matches!(db.check_version(), Err(Error::OutOfDate));

        db.migrate(migrate::ignore).unwrap();
        assert_matches!(db.check_version(), Ok(()));
    }

    #[test]
    fn test_migrate_to() {
        let mut db = StoreWriter::memory().unwrap();
        assert_eq!(db.version().unwrap(), 0);

        assert_eq!(db.migrate_to(1, migrate::ignore).unwrap(), 1); // 0 -> 1
        assert_eq!(db.version().unwrap(), 1);

        assert_eq!(db.migrate_to(2, migrate::ignore).unwrap(), 2); // 1 -> 2
        assert_eq!(db.version().unwrap(), 2);

        assert_eq!(db.migrate_to(3, migrate::ignore).unwrap(), 3); // 2 -> 3
        assert_eq!(db.version().unwrap(), 3);

        assert_eq!(db.migrate_to(1, migrate::ignore).unwrap(), 3); // No-op.
        assert_eq!(db.version().unwrap(), 3);

        assert_eq!(db.migrate_to(99, migrate::ignore).unwrap(), 3); // No-op.
        assert_eq!(db.version().unwrap(), 3);
    }
}