Radish alpha
h
rad:z3gqcJUoA1n9HaHKufZs5FCSGazv5
Radicle Heartwood Protocol & Stack
Radicle
Git
heartwood crates radicle-node src control.rs
//! Client control socket implementation.
use std::io::BufReader;
use std::io::LineWriter;
use std::io::prelude::*;
use std::path::PathBuf;
use std::{io, net, time};

use radicle::storage::refs;
#[cfg(unix)]
use std::os::unix::net::{UnixListener, UnixStream};
#[cfg(windows)]
use uds_windows::{UnixListener, UnixStream};

use radicle::node::Handle;
use serde_json as json;

use crate::identity::RepoId;
use crate::node::NodeId;
use crate::node::{Command, CommandResult};
use crate::runtime;
use crate::runtime::thread;

/// Maximum timeout for waiting for node events.
///
/// Passed to `Handle::subscribe` when serving `Command::Subscribe`. Using
/// `Duration::MAX` presumably means subscribers effectively never time out while
/// waiting for the next event — NOTE(review): depends on how the handle
/// interprets the timeout; confirm.
const MAX_TIMEOUT: time::Duration = time::Duration::MAX;

/// Errors from setting up or running the control socket.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The control socket listener could not be bound.
    ///
    /// Not constructed in this module section; presumably produced by the code
    /// that creates the `UnixListener` passed to [`listen`].
    #[error("failed to bind control socket listener: {0}")]
    Bind(io::Error),
    /// The control socket path that was specified is not usable.
    #[error("invalid socket path specified: {0}")]
    InvalidPath(PathBuf),
    /// A request made through the node handle failed (e.g. `handle.nid()` in
    /// [`listen`]).
    #[error("node: {0}")]
    Node(#[from] runtime::HandleError),
}

/// Serve the control socket: accept connections on `listener` and handle each one
/// on its own worker thread.
///
/// Each accepted connection carries a single command, which is executed against a
/// clone of `handle`. When handling fails, the error is sent back to the client as a
/// `CommandResult` error and the connection is shut down. Failed accepts are logged
/// and skipped.
///
/// If the `incoming()` iterator ever ends, the loop exits and `Ok(())` is returned.
///
/// # Errors
///
/// Returns [`Error::Node`] if the node ID cannot be obtained from `handle`.
pub fn listen<E, H>(listener: UnixListener, handle: H) -> Result<(), Error>
where
    H: Handle<Error = runtime::HandleError> + 'static,
    H::Sessions: serde::Serialize,
    CommandResult<E>: From<H::Event>,
    E: serde::Serialize,
{
    log::debug!(target: "control", "Control thread listening on socket..");
    // Only passed to `thread::spawn` below (presumably to label worker threads).
    let nid = handle.nid()?;

    for conn in listener.incoming() {
        let mut stream = match conn {
            Ok(stream) => stream,
            Err(e) => {
                log::warn!(target: "control", "Failed to accept incoming connection: {e}");
                continue;
            }
        };
        let worker_handle = handle.clone();

        thread::spawn(&nid, "control", move || {
            let Err(e) = command(&stream, worker_handle) else {
                return;
            };
            log::debug!(target: "control", "Command returned error: {e}");

            // Best effort: the peer may already be gone, so failures while reporting
            // the error back are ignored.
            CommandResult::error(e).to_writer(&mut stream).ok();
            stream.flush().ok();
            stream.shutdown(net::Shutdown::Both).ok();
        });
    }
    log::debug!(target: "control", "Exiting control loop..");

    Ok(())
}

/// Errors that can occur while handling a single control socket command.
///
/// When `command` fails, [`listen`] reports the error back to the client as a
/// `CommandResult` error.
#[derive(thiserror::Error, Debug)]
enum CommandError {
    /// The request line was not valid JSON for a `Command`, or a response could not
    /// be serialized.
    #[error("(de)serialization failed: {0}")]
    Serialization(#[from] json::Error),
    /// The node handle returned an error while executing the command.
    #[error("runtime error: {0}")]
    Runtime(#[from] runtime::HandleError),
    /// Reading the request from, or writing a response to, the socket failed.
    #[error("i/o error: {0}")]
    Io(#[from] io::Error),
}

/// Handle one request received on a control socket connection.
///
/// Reads a single line from `stream`, parses it as a JSON-encoded [`Command`],
/// executes it through `handle`, and writes the JSON response(s) back on the same
/// stream. Only the first line is read, so each connection carries one command.
///
/// # Errors
///
/// Returns a [`CommandError`] if the request cannot be read or parsed, if the node
/// handle reports an error, or if writing a response fails. [`listen`] is
/// responsible for reporting that error to the client.
fn command<E, H>(stream: &UnixStream, mut handle: H) -> Result<(), CommandError>
where
    H: Handle<Error = runtime::HandleError> + 'static,
    H::Sessions: serde::Serialize,
    CommandResult<E>: From<H::Event>,
    E: serde::Serialize,
{
    // `&UnixStream` is both `Read` and `Write`, so the same socket backs both halves.
    let mut reader = BufReader::new(stream);
    // `LineWriter` flushes whenever a newline is written, so every newline-terminated
    // response is sent immediately (relevant for the streaming `Subscribe` command).
    let mut writer = LineWriter::new(stream);

    let mut line = String::new();

    reader.read_line(&mut line)?;
    let input = line.trim_end();

    log::debug!(target: "control", "Received `{input}` on control socket");
    let cmd: Command = json::from_str(input)?;

    // Handle errors are converted into `CommandError::Runtime` by `?` via `#[from]`.
    match cmd {
        Command::Connect { addr, opts } => {
            let (nid, addr) = addr.into();
            let result = handle.connect(nid, addr, opts)?;

            json::to_writer(&mut writer, &result)?;
            writer.write_all(b"\n")?;
        }
        Command::Disconnect { nid } => {
            handle.disconnect(nid)?;
            // Failure to write the acknowledgement is ignored (best effort).
            CommandResult::ok().to_writer(writer).ok();
        }
        Command::Fetch {
            rid,
            nid,
            timeout,
            signed_references_minimum_feature_level,
        } => {
            fetch(
                rid,
                nid,
                timeout,
                signed_references_minimum_feature_level,
                writer,
                &mut handle,
            )?;
        }
        Command::Config => {
            let config = handle.config()?;

            CommandResult::Okay(config).to_writer(writer)?;
        }
        Command::ListenAddrs => {
            let addrs = handle.listen_addrs()?;

            CommandResult::Okay(addrs).to_writer(writer)?;
        }
        // Deprecated command, still served here; presumably kept for older clients
        // (`SeedsFor` is the namespaced variant).
        #[allow(deprecated)]
        Command::Seeds { rid } => {
            let seeds = handle.seeds(rid)?;

            CommandResult::Okay(seeds).to_writer(writer)?;
        }
        Command::SeedsFor { rid, namespaces } => {
            let seeds = handle.seeds_for(rid, namespaces)?;

            CommandResult::Okay(seeds).to_writer(writer)?;
        }
        Command::Sessions => {
            let sessions = handle.sessions()?;

            CommandResult::Okay(sessions).to_writer(writer)?;
        }
        Command::Session { nid } => {
            let session = handle.session(nid)?;

            CommandResult::Okay(session).to_writer(writer)?;
        }
        Command::Seed { rid, scope } => {
            let updated = handle.seed(rid, scope)?;

            CommandResult::updated(updated).to_writer(writer)?;
        }
        Command::Unseed { rid } => {
            let updated = handle.unseed(rid)?;

            CommandResult::updated(updated).to_writer(writer)?;
        }
        Command::Follow { nid, alias } => {
            let updated = handle.follow(nid, alias)?;

            CommandResult::updated(updated).to_writer(writer)?;
        }
        Command::Block { nid } => {
            let updated = handle.block(nid)?;

            CommandResult::updated(updated).to_writer(writer)?;
        }
        Command::Unfollow { nid } => {
            let updated = handle.unfollow(nid)?;

            CommandResult::updated(updated).to_writer(writer)?;
        }
        // Deprecated command, still served here; presumably kept for older clients
        // (`AnnounceRefsFor` is the namespaced variant).
        #[allow(deprecated)]
        Command::AnnounceRefs { rid } => {
            let refs = handle.announce_refs(rid)?;

            CommandResult::Okay(refs).to_writer(writer)?;
        }
        Command::AnnounceRefsFor { rid, namespaces } => {
            let refs = handle.announce_refs_for(rid, namespaces)?;

            CommandResult::Okay(refs).to_writer(writer)?;
        }
        Command::AnnounceInventory => {
            handle.announce_inventory()?;
            // Failure to write the acknowledgement is ignored (best effort).
            CommandResult::ok().to_writer(writer).ok();
        }
        Command::AddInventory { rid } => {
            let updated = handle.add_inventory(rid)?;

            CommandResult::updated(updated).to_writer(writer)?;
        }
        Command::Subscribe => {
            // Each event received from the subscription is written as its own
            // response; this keeps running for as long as the subscription yields.
            for event in handle.subscribe(MAX_TIMEOUT)? {
                CommandResult::from(event).to_writer(&mut writer)?;
            }
        }
        Command::Status => {
            CommandResult::ok().to_writer(writer).ok();
        }
        Command::NodeId => {
            let nid = handle.nid()?;

            CommandResult::Okay(nid).to_writer(writer)?;
        }
        Command::Debug => {
            let debug = handle.debug()?;

            CommandResult::Okay(debug).to_writer(writer)?;
        }
        Command::Shutdown => {
            log::debug!(target: "control", "Shutdown requested..");
            // Channel might already be disconnected if shutdown
            // came from somewhere else. Ignore errors.
            handle.shutdown().ok();
            CommandResult::ok().to_writer(writer).ok();
        }
    }
    Ok(())
}

/// Fetch repository `id` from `node` through `handle` and write the result to `writer`.
///
/// The result is serialized as a single JSON line terminated by `\n`. This matches the
/// newline-delimited framing of the other control socket responses (e.g. `Connect`),
/// so line-oriented clients see the end of the response without waiting for the
/// connection to close.
///
/// # Errors
///
/// Returns [`CommandError::Runtime`] if the handle fails to perform the fetch, and
/// [`CommandError::Serialization`] or [`CommandError::Io`] if the result cannot be
/// written to `writer`.
fn fetch<W: Write, H: Handle<Error = runtime::HandleError>>(
    id: RepoId,
    node: NodeId,
    timeout: time::Duration,
    signed_references_minimum_feature_level: Option<refs::FeatureLevel>,
    mut writer: W,
    handle: &mut H,
) -> Result<(), CommandError> {
    let result = handle.fetch(id, node, timeout, signed_references_minimum_feature_level)?;

    json::to_writer(&mut writer, &result)?;
    // Terminate the response line, like the `Connect` response does.
    writer.write_all(b"\n")?;
    // Flush explicitly: if `writer` is buffered, relying on its `Drop` would
    // silently discard any flush error.
    writer.flush()?;

    Ok(())
}

#[cfg(test)]
mod tests {
    use std::io::prelude::*;
    use std::thread;

    use super::*;
    use crate::identity::RepoId;
    use crate::node::Handle;
    use crate::node::policy::Scope;
    use crate::node::{Alias, Node, NodeId};
    use crate::test;

    /// Drives the control socket with raw JSON requests.
    ///
    /// Sends one `AnnounceRefsFor` command per connection, checks the single-line
    /// JSON response, and finally checks that the test handle recorded an update
    /// for every repository.
    #[test]
    fn test_control_socket() {
        let tmp = tempfile::tempdir().unwrap();
        let handle = test::handle::Handle::default();
        let socket = tmp.path().join("alice.sock");
        let rids = test::arbitrary::set::<RepoId>(1..3);
        let listener = UnixListener::bind(&socket).unwrap();
        let nid = handle.nid().unwrap();

        // Serve the socket in the background with a clone of the handle. The final
        // assertion reads `handle.updates` on the original, which presumably relies
        // on clones of the test handle sharing that state.
        thread::spawn({
            let handle = handle.clone();

            move || listen(listener, handle)
        });

        for rid in &rids {
            // Retry until a connection to the socket succeeds.
            let mut stream = loop {
                if let Ok(stream) = UnixStream::connect(&socket) {
                    break stream;
                }
            };
            // One newline-terminated JSON command per connection.
            writeln!(
                &mut stream,
                "{}",
                json::to_string(&Command::AnnounceRefsFor {
                    rid: rid.to_owned(),
                    namespaces: [nid].into(),
                })
                .unwrap()
            )
            .unwrap();

            let stream = BufReader::new(stream);
            let line = stream.lines().next().unwrap().unwrap();

            // Expected response: the announcing node and an all-zero OID
            // (presumably the test handle's canned value).
            assert_eq!(
                line,
                json::json!({
                    "remote": handle.nid().unwrap(),
                    "at": "0000000000000000000000000000000000000000"
                })
                .to_string()
            );
        }

        for rid in &rids {
            assert!(handle.updates.lock().unwrap().contains(&(*rid, nid)));
        }
    }

    /// Exercises seed/unseed and follow/unfollow through the `Node` client.
    ///
    /// The first call of each operation reports a change (`true`); repeating it
    /// reports no change (`false`).
    #[test]
    fn test_seed_unseed() {
        let tmp = tempfile::tempdir().unwrap();
        let socket = tmp.path().join("node.sock");
        let proj = test::arbitrary::r#gen::<RepoId>(1);
        let peer = test::arbitrary::r#gen::<NodeId>(1);
        let listener = UnixListener::bind(&socket).unwrap();
        let mut handle = Node::new(&socket);

        thread::spawn({
            let handle = crate::test::handle::Handle::default();

            move || crate::control::listen(listener, handle)
        });

        // Wait for node to be online.
        // NOTE(review): busy-waits without sleeping until `is_running` returns true.
        while !handle.is_running() {}

        assert!(handle.seed(proj, Scope::default()).unwrap());
        assert!(!handle.seed(proj, Scope::default()).unwrap());
        assert!(handle.unseed(proj).unwrap());
        assert!(!handle.unseed(proj).unwrap());

        assert!(handle.follow(peer, Some(Alias::new("alice"))).unwrap());
        assert!(!handle.follow(peer, Some(Alias::new("alice"))).unwrap());
        assert!(handle.unfollow(peer).unwrap());
        assert!(!handle.unfollow(peer).unwrap());
    }
}