use std::marker::PhantomData;
use std::num::TryFromIntError;
use std::path::Path;
use std::sync::Arc;
use std::{fmt, io, str::FromStr, time};
use localtime::LocalTime;
use sqlite as sql;
use thiserror::Error;
use crate::git;
use crate::git::Oid;
use crate::git::RefError;
use crate::git::fmt::RefString;
use crate::prelude::RepoId;
use crate::sql::transaction;
use crate::storage::RefUpdate;
use super::{
Notification, NotificationId, NotificationKind, NotificationKindError, NotificationStatus,
};
/// How long to wait for the database lock to be released before failing a read.
const DB_READ_TIMEOUT: time::Duration = time::Duration::from_secs(3);
/// How long to wait for the database lock to be released before failing a write.
const DB_WRITE_TIMEOUT: time::Duration = time::Duration::from_secs(6);
#[derive(Error, Debug)]
pub enum Error {
/// I/O error.
#[error("i/o error: {0}")]
Io(#[from] io::Error),
/// An Internal error.
#[error("internal error: {0}")]
Internal(#[from] sql::Error),
/// Timestamp error.
#[error("invalid timestamp: {0}")]
Timestamp(#[from] TryFromIntError),
/// Invalid Git ref name.
#[error("invalid ref name: {0}")]
RefName(#[from] RefError),
/// Invalid Git ref format.
#[error("invalid ref format: {0}")]
RefFormat(#[from] crate::git::fmt::Error),
/// Invalid notification kind.
#[error("invalid notification kind: {0}")]
NotificationKind(#[from] NotificationKindError),
/// Not found.
#[error("notification {0} not found")]
NotificationNotFound(NotificationId),
/// Internal unit overflow.
#[error("the unit overflowed")]
UnitOverflow,
}
/// Read-only type witness.
#[derive(Clone)]
pub struct Read;
/// Read-write type witness.
#[derive(Clone)]
pub struct Write;
/// Notifications store.
#[derive(Clone)]
pub struct Store<T> {
db: Arc<sql::ConnectionThreadSafe>,
marker: PhantomData<T>,
}
impl<T> fmt::Debug for Store<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Store(..)")
}
}
impl Store<Read> {
const SCHEMA: &'static str = include_str!("schema.sql");
/// Same as [`Self::open`], but in read-only mode. This is useful to have multiple
/// open databases, as no locking is required.
pub fn reader<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
let mut db = sql::Connection::open_thread_safe_with_flags(
path,
sqlite::OpenFlags::new().with_read_only(),
)?;
db.set_busy_timeout(DB_READ_TIMEOUT.as_millis() as usize)?;
db.execute(Self::SCHEMA)?;
Ok(Self {
db: Arc::new(db),
marker: PhantomData,
})
}
/// Create a new in-memory address book.
pub fn memory() -> Result<Self, Error> {
let db = sql::Connection::open_thread_safe_with_flags(
":memory:",
sqlite::OpenFlags::new().with_read_only(),
)?;
db.execute(Self::SCHEMA)?;
Ok(Self {
db: Arc::new(db),
marker: PhantomData,
})
}
}
impl Store<Write> {
const SCHEMA: &'static str = include_str!("schema.sql");
/// Open a policy store at the given path. Creates a new store if it
/// doesn't exist.
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
let mut db = sql::Connection::open_thread_safe(path)?;
db.set_busy_timeout(DB_WRITE_TIMEOUT.as_millis() as usize)?;
db.execute(Self::SCHEMA)?;
Ok(Self {
db: Arc::new(db),
marker: PhantomData,
})
}
/// Create a new in-memory address book.
pub fn memory() -> Result<Self, Error> {
let db = sql::Connection::open_thread_safe(":memory:")?;
db.execute(Self::SCHEMA)?;
Ok(Self {
db: Arc::new(db),
marker: PhantomData,
})
}
/// Get a read-only version of this store.
pub fn read_only(self) -> Store<Read> {
Store {
db: self.db,
marker: PhantomData,
}
}
/// Set notification read status for the given notifications.
pub fn set_status(
&mut self,
status: NotificationStatus,
ids: &[NotificationId],
) -> Result<bool, Error> {
transaction(&self.db, |_| {
let mut stmt = self.db.prepare(
"UPDATE `repository-notifications`
SET status = ?1
WHERE rowid = ?2",
)?;
for id in ids {
stmt.bind((1, &status))?;
stmt.bind((2, *id as i64))?;
stmt.next()?;
stmt.reset()?;
}
Ok(self.db.change_count() > 0)
})
}
/// Insert a notification. Resets the status to *unread* if it already exists.
pub fn insert(
&mut self,
repo: &RepoId,
update: &RefUpdate,
timestamp: LocalTime,
) -> Result<bool, Error> {
let mut stmt = self.db.prepare(
"INSERT INTO `repository-notifications` (repo, ref, old, new, timestamp)
VALUES (?1, ?2, ?3, ?4, ?5)
ON CONFLICT DO UPDATE
SET old = ?3, new = ?4, timestamp = ?5, status = null",
)?;
let old = update.old().map(|o| o.to_string());
let new = update.new().map(|o| o.to_string());
stmt.bind((1, repo))?;
stmt.bind((2, update.name().as_str()))?;
stmt.bind((3, old.as_deref()))?;
stmt.bind((4, new.as_deref()))?;
stmt.bind((5, i64::try_from(timestamp.as_millis())?))?;
stmt.next()?;
Ok(self.db.change_count() > 0)
}
/// Delete the given notifications.
pub fn clear(&mut self, ids: &[NotificationId]) -> Result<usize, Error> {
transaction(&self.db, |_| {
let mut stmt = self
.db
.prepare("DELETE FROM `repository-notifications` WHERE rowid = ?")?;
// N.b. we need to keep the count manually since the change count
// will always be `1` because of each reset.
let mut count = 0;
for id in ids {
stmt.bind((1, *id as i64))?;
stmt.next()?;
stmt.reset()?;
count += self.db.change_count();
}
Ok(count)
})
}
/// Delete all notifications of a repo.
pub fn clear_by_repo(&mut self, repo: &RepoId) -> Result<usize, Error> {
let mut stmt = self
.db
.prepare("DELETE FROM `repository-notifications` WHERE repo = ?")?;
stmt.bind((1, repo))?;
stmt.next()?;
Ok(self.db.change_count())
}
/// Delete all notifications from all repos.
pub fn clear_all(&mut self) -> Result<usize, Error> {
self.db
.prepare("DELETE FROM `repository-notifications`")?
.next()?;
Ok(self.db.change_count())
}
}
/// `Read` methods for `Store`. This implies that a
/// `Store<Write>` can access these functions as well.
impl<T> Store<T> {
/// Get a specific notification.
pub fn get(&self, id: NotificationId) -> Result<Notification, Error> {
let mut stmt = self.db.prepare(
"SELECT rowid, repo, ref, old, new, status, timestamp
FROM `repository-notifications`
WHERE rowid = ?",
)?;
stmt.bind((1, id as i64))?;
if let Some(Ok(row)) = stmt.into_iter().next() {
return parse::notification(row);
}
Err(Error::NotificationNotFound(id))
}
/// Get all notifications.
pub fn all(&self) -> Result<impl Iterator<Item = Result<Notification, Error>> + '_, Error> {
let stmt = self.db.prepare(
"SELECT rowid, repo, ref, old, new, status, timestamp
FROM `repository-notifications`
ORDER BY timestamp DESC",
)?;
Ok(stmt.into_iter().map(move |row| {
let row = row?;
parse::notification(row)
}))
}
// Get notifications that were created between the given times: `since <= t < until`.
pub fn by_timestamp(
&self,
since: LocalTime,
until: LocalTime,
) -> Result<impl Iterator<Item = Result<Notification, Error>> + '_, Error> {
let mut stmt = self.db.prepare(
"SELECT rowid, repo, ref, old, new, status, timestamp
FROM `repository-notifications`
WHERE timestamp >= ?1 AND timestamp < ?2
ORDER BY timestamp",
)?;
let since = i64::try_from(since.as_millis())?;
let until = i64::try_from(until.as_millis())?;
stmt.bind((1, since))?;
stmt.bind((2, until))?;
Ok(stmt.into_iter().map(move |row| {
let row = row?;
parse::notification(row)
}))
}
/// Get notifications by repo.
pub fn by_repo(
&self,
repo: &RepoId,
order_by: &str,
) -> Result<impl Iterator<Item = Result<Notification, Error>> + '_ + use<'_, T>, Error> {
let mut stmt = self.db.prepare(format!(
"SELECT rowid, repo, ref, old, new, status, timestamp
FROM `repository-notifications`
WHERE repo = ?
ORDER BY {order_by} DESC",
))?;
stmt.bind((1, repo))?;
Ok(stmt.into_iter().map(move |row| {
let row = row?;
parse::notification(row)
}))
}
/// Get the total notification count.
pub fn count(&self) -> Result<usize, Error> {
let stmt = self
.db
.prepare("SELECT COUNT(*) FROM `repository-notifications`")?;
let count: i64 = stmt
.into_iter()
.next()
.expect("COUNT will always return a single row")?
.read(0);
let count: usize = count.try_into().map_err(|_| Error::UnitOverflow)?;
Ok(count)
}
/// Get the total notification count by repos.
pub fn counts_by_repo(
&self,
) -> Result<impl Iterator<Item = Result<(RepoId, usize), Error>> + '_, Error> {
let stmt = self.db.prepare(
"SELECT repo, COUNT(*) as count
FROM `repository-notifications`
GROUP BY repo",
)?;
Ok(stmt.into_iter().map(|row| {
let row = row?;
let count = row.try_read::<i64, _>("count")? as usize;
let repo = row.try_read::<RepoId, _>("repo")?;
Ok((repo, count))
}))
}
/// Get the notification count for the given repo.
pub fn count_by_repo(&self, repo: &RepoId) -> Result<usize, Error> {
let mut stmt = self
.db
.prepare("SELECT COUNT(*) FROM `repository-notifications` WHERE repo = ?")?;
stmt.bind((1, repo))?;
let count: i64 = stmt
.into_iter()
.next()
.expect("COUNT will always return a single row")?
.read(0);
let count: usize = count.try_into().map_err(|_| Error::UnitOverflow)?;
Ok(count)
}
}
mod parse {
use super::*;
pub fn notification(row: sql::Row) -> Result<Notification, Error> {
let id = row.try_read::<i64, _>("rowid")? as NotificationId;
let repo = row.try_read::<RepoId, _>("repo")?;
let refstr = row.try_read::<&str, _>("ref")?;
let status = row.try_read::<NotificationStatus, _>("status")?;
let old = row
.try_read::<Option<&str>, _>("old")?
.map(|oid| {
Oid::from_str(oid).map_err(|e| {
Error::Internal(sql::Error {
code: None,
message: Some(format!("sql: invalid oid in `old` column: {oid:?}: {e}")),
})
})
})
.unwrap_or(Ok(git::Oid::ZERO_SHA1))?;
let new = row
.try_read::<Option<&str>, _>("new")?
.map(|oid| {
Oid::from_str(oid).map_err(|e| {
Error::Internal(sql::Error {
code: None,
message: Some(format!("sql: invalid oid in `new` column: {oid:?}: {e}")),
})
})
})
.unwrap_or(Ok(git::Oid::ZERO_SHA1))?;
let update = RefUpdate::from(RefString::try_from(refstr)?, old, new);
let (namespace, qualified) = git::parse_ref(refstr)?;
let timestamp = row.try_read::<i64, _>("timestamp")?;
let timestamp = LocalTime::from_millis(timestamp as u128);
let qualified = qualified.to_owned();
let kind = NotificationKind::try_from(qualified.clone())?;
Ok(Notification {
id,
repo,
update,
remote: namespace,
qualified,
status,
kind,
timestamp,
})
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
use crate::git::fmt::{qualified, refname};
use crate::{cob, node::NodeId, test::arbitrary};
use super::*;
#[test]
fn test_clear() {
let mut db = Store::open(":memory:").unwrap();
let repo = arbitrary::r#gen::<RepoId>(1);
let old = arbitrary::oid();
let time = LocalTime::from_millis(32188142);
let master = arbitrary::oid();
for i in 0..3 {
let update = RefUpdate::Updated {
name: format!("refs/heads/feature/{i}").try_into().unwrap(),
old,
new: master,
};
assert!(db.insert(&repo, &update, time).unwrap());
}
assert_eq!(db.count().unwrap(), 3);
assert_eq!(db.count_by_repo(&repo).unwrap(), 3);
db.clear_by_repo(&repo).unwrap();
assert_eq!(db.count().unwrap(), 0);
assert_eq!(db.count_by_repo(&repo).unwrap(), 0);
}
#[test]
fn test_counts_by_repo() {
let mut db = Store::open(":memory:").unwrap();
let repo1 = arbitrary::r#gen::<RepoId>(1);
let repo2 = arbitrary::r#gen::<RepoId>(1);
let oid = arbitrary::oid();
let time = LocalTime::from_millis(32188142);
let update1 = RefUpdate::Created {
name: refname!("refs/heads/feature/1"),
oid,
};
let update2 = RefUpdate::Created {
name: refname!("refs/heads/feature/2"),
oid,
};
let update3 = RefUpdate::Created {
name: refname!("refs/heads/feature/3"),
oid,
};
assert!(db.insert(&repo1, &update1, time).unwrap());
assert!(db.insert(&repo1, &update2, time).unwrap());
assert!(db.insert(&repo2, &update3, time).unwrap());
let counts = db
.counts_by_repo()
.unwrap()
.collect::<Result<std::collections::HashMap<_, _>, _>>()
.unwrap();
assert_eq!(counts.get(&repo1).unwrap(), &2);
assert_eq!(counts.get(&repo2).unwrap(), &1);
}
#[test]
fn test_branch_notifications() {
let repo = arbitrary::r#gen::<RepoId>(1);
let old = arbitrary::oid();
let master = arbitrary::oid();
let other = arbitrary::oid();
let time1 = LocalTime::from_millis(32188142);
let time2 = LocalTime::from_millis(32189874);
let time3 = LocalTime::from_millis(32189879);
let mut db = Store::open(":memory:").unwrap();
let update1 = RefUpdate::Updated {
name: refname!("refs/heads/master"),
old,
new: master,
};
let update2 = RefUpdate::Created {
name: refname!("refs/heads/other"),
oid: other,
};
let update3 = RefUpdate::Deleted {
name: refname!("refs/heads/dev"),
oid: other,
};
assert!(db.insert(&repo, &update1, time1).unwrap());
assert!(db.insert(&repo, &update2, time2).unwrap());
assert!(db.insert(&repo, &update3, time3).unwrap());
let mut notifs = db.by_repo(&repo, "timestamp").unwrap();
assert_eq!(
notifs.next().unwrap().unwrap(),
Notification {
id: 3,
repo,
remote: None,
qualified: qualified!("refs/heads/dev"),
update: update3,
kind: NotificationKind::Branch {
name: refname!("dev")
},
status: NotificationStatus::Unread,
timestamp: time3,
}
);
assert_eq!(
notifs.next().unwrap().unwrap(),
Notification {
id: 2,
repo,
remote: None,
qualified: qualified!("refs/heads/other"),
update: update2,
kind: NotificationKind::Branch {
name: refname!("other")
},
status: NotificationStatus::Unread,
timestamp: time2,
}
);
assert_eq!(
notifs.next().unwrap().unwrap(),
Notification {
id: 1,
repo,
remote: None,
qualified: qualified!("refs/heads/master"),
update: update1,
kind: NotificationKind::Branch {
name: refname!("master")
},
status: NotificationStatus::Unread,
timestamp: time1,
}
);
assert!(notifs.next().is_none());
}
#[test]
fn test_notification_status() {
let repo = arbitrary::r#gen::<RepoId>(1);
let oid = arbitrary::oid();
let time = LocalTime::from_millis(32188142);
let mut db = Store::open(":memory:").unwrap();
let update1 = RefUpdate::Created {
name: refname!("refs/heads/feature/1"),
oid,
};
let update2 = RefUpdate::Created {
name: refname!("refs/heads/feature/2"),
oid,
};
let update3 = RefUpdate::Created {
name: refname!("refs/heads/feature/3"),
oid,
};
assert!(db.insert(&repo, &update1, time).unwrap());
assert!(db.insert(&repo, &update2, time).unwrap());
assert!(db.insert(&repo, &update3, time).unwrap());
assert!(
db.set_status(NotificationStatus::ReadAt(time), &[1, 2, 3])
.unwrap()
);
let mut notifs = db.by_repo(&repo, "timestamp").unwrap();
assert_eq!(
notifs.next().unwrap().unwrap().status,
NotificationStatus::ReadAt(time),
);
assert_eq!(
notifs.next().unwrap().unwrap().status,
NotificationStatus::ReadAt(time),
);
assert_eq!(
notifs.next().unwrap().unwrap().status,
NotificationStatus::ReadAt(time),
);
}
#[test]
fn test_duplicate_notifications() {
let repo = arbitrary::r#gen::<RepoId>(1);
let old = arbitrary::oid();
let master1 = arbitrary::oid();
let master2 = arbitrary::oid();
let time1 = LocalTime::from_millis(32188142);
let time2 = LocalTime::from_millis(32189874);
let mut db = Store::open(":memory:").unwrap();
let update1 = RefUpdate::Updated {
name: refname!("refs/heads/master"),
old,
new: master1,
};
let update2 = RefUpdate::Updated {
name: refname!("refs/heads/master"),
old: master1,
new: master2,
};
assert!(db.insert(&repo, &update1, time1).unwrap());
assert!(
db.set_status(NotificationStatus::ReadAt(time1), &[1])
.unwrap()
);
assert!(db.insert(&repo, &update2, time2).unwrap());
let mut notifs = db.by_repo(&repo, "timestamp").unwrap();
assert_eq!(
notifs.next().unwrap().unwrap(),
Notification {
id: 1,
repo,
remote: None,
qualified: qualified!("refs/heads/master"),
update: update2,
kind: NotificationKind::Branch {
name: refname!("master")
},
// Status is reset to "unread".
status: NotificationStatus::Unread,
timestamp: time2,
}
);
assert!(notifs.next().is_none());
}
#[test]
fn test_cob_notifications() {
let repo = arbitrary::r#gen::<RepoId>(1);
let old = arbitrary::oid();
let new = arbitrary::oid();
let timestamp = LocalTime::from_millis(32189874);
let nid: NodeId = "z6MknSLrJoTcukLrE435hVNQT4JUhbvWLX4kUzqkEStBU8Vi"
.parse()
.unwrap();
let mut db = Store::open(":memory:").unwrap();
let qualified =
qualified!("refs/cobs/xyz.radicle.issue/d87dcfe8c2b3200e78b128d9b959cfdf7063fefe");
let namespaced = qualified.with_namespace((&nid).into());
let update = RefUpdate::Updated {
name: namespaced.to_ref_string(),
old,
new,
};
assert!(db.insert(&repo, &update, timestamp).unwrap());
let mut notifs = db.by_repo(&repo, "timestamp").unwrap();
assert_eq!(
notifs.next().unwrap().unwrap(),
Notification {
id: 1,
repo,
remote: Some(
"z6MknSLrJoTcukLrE435hVNQT4JUhbvWLX4kUzqkEStBU8Vi"
.parse()
.unwrap()
),
qualified,
update,
kind: NotificationKind::Cob {
typed_id: cob::TypedId {
type_name: cob::issue::TYPENAME.clone(),
id: "d87dcfe8c2b3200e78b128d9b959cfdf7063fefe".parse().unwrap(),
},
},
status: NotificationStatus::Unread,
timestamp,
}
);
assert!(notifs.next().is_none());
}
}