Radish alpha
h
Radicle Heartwood Protocol & Stack
Radicle
Git (anonymous pull)
Log in to clone via SSH
node: Add a more elaborate message extension test
cloudhead committed 1 year ago
commit 83786fbd80ee6ca9fdc19c32e5353d3439d219ab
parent 82c5884fdc7620cd4b0828a89fa8402cac024b18
2 files changed +116 -13
modified radicle-node/src/wire/frame.rs
@@ -191,16 +191,16 @@ impl TryFrom<u8> for StreamKind {
/// |                     Data                                   ...| Data (variable size)
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
#[derive(Debug, PartialEq, Eq)]
-
pub struct Frame {
+
pub struct Frame<M = Message> {
    /// The protocol version.
    pub version: Version,
    /// The stream identifier.
    pub stream: StreamId,
    /// The frame payload.
-
    pub data: FrameData,
+
    pub data: FrameData<M>,
}

-
impl Frame {
+
impl<M> Frame<M> {
    /// Create a 'git' protocol frame.
    pub fn git(stream: StreamId, data: Vec<u8>) -> Self {
        Self {
@@ -220,14 +220,16 @@ impl Frame {
    }

    /// Create a 'gossip' protocol frame.
-
    pub fn gossip(link: Link, msg: Message) -> Self {
+
    pub fn gossip(link: Link, msg: M) -> Self {
        Self {
            version: PROTOCOL_VERSION_STRING,
            stream: StreamId::gossip(link),
            data: FrameData::Gossip(msg),
        }
    }
+
}

+
impl<M: wire::Encode> Frame<M> {
    /// Serialize frame to bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        wire::serialize(self)
@@ -236,11 +238,11 @@ impl Frame {

/// Frame payload.
#[derive(Debug, PartialEq, Eq)]
-
pub enum FrameData {
+
pub enum FrameData<M> {
    /// Control frame payload.
    Control(Control),
    /// Gossip frame payload.
-
    Gossip(Message),
+
    Gossip(M),
    /// Git frame payload. May contain packet-lines as well as packfile data.
    Git(Vec<u8>),
}
@@ -310,7 +312,7 @@ impl wire::Encode for Control {
    }
}

-
impl wire::Decode for Frame {
+
impl<M: wire::Decode> wire::Decode for Frame<M> {
    fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, wire::Error> {
        let version = Version::decode(reader)?;
        if version.number() != PROTOCOL_VERSION {
@@ -331,7 +333,7 @@ impl wire::Decode for Frame {
            Ok(StreamKind::Gossip) => {
                let data = varint::payload::decode(reader)?;
                let mut cursor = io::Cursor::new(data);
-
                let msg = Message::decode(&mut cursor)?;
+
                let msg = M::decode(&mut cursor)?;
                let frame = Frame {
                    version,
                    stream,
@@ -352,7 +354,7 @@ impl wire::Decode for Frame {
    }
}

-
impl wire::Encode for Frame {
+
impl<M: wire::Encode> wire::Encode for Frame<M> {
    fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
        let mut n = 0;

modified radicle-node/src/wire/protocol.rs
@@ -418,7 +418,7 @@ where
                    target: "wire", "Stream {} of {} closing with {} byte(s) sent and {} byte(s) received",
                    task.stream, task.remote, s.sent_bytes, s.received_bytes
                );
-
                let frame = Frame::control(
+
                let frame = Frame::<service::Message>::control(
                    *link,
                    frame::Control::Close {
                        stream: task.stream,
@@ -470,7 +470,7 @@ where
                ChannelEvent::Data(data) => {
                    metrics.sent_git_bytes += data.len();
                    metrics.sent_bytes += data.len();
-
                    Frame::git(stream, data)
+
                    Frame::<service::Message>::git(stream, data)
                }
                ChannelEvent::Close => Frame::control(*link, frame::Control::Close { stream }),
                ChannelEvent::Eof => Frame::control(*link, frame::Control::Eof { stream }),
@@ -1109,7 +1109,8 @@ where

                    self.actions.push_back(Action::Send(
                        fd,
-
                        Frame::control(link, frame::Control::Open { stream }).to_bytes(),
+
                        Frame::<service::Message>::control(link, frame::Control::Open { stream })
+
                            .to_bytes(),
                    ));
                }
            }
@@ -1228,7 +1229,7 @@ mod test {
    use crate::wire::varint;

    #[test]
-
    fn test_message_with_extension() {
+
    fn test_pong_message_with_extension() {
        use crate::deserializer;

        let mut stream = Vec::new();
@@ -1259,4 +1260,104 @@ mod test {
        assert!(de.deserialize_next().unwrap().is_none());
        assert!(de.is_empty());
    }
+

+
    #[test]
+
    fn test_inventory_ann_with_extension() {
+
        use crate::deserializer;
+

+
        #[derive(Debug)]
+
        struct MessageWithExt {
+
            msg: Message,
+
            ext: String,
+
        }
+

+
        impl wire::Encode for MessageWithExt {
+
            fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
+
                let mut n = self.msg.encode(writer)?;
+
                n += self.ext.encode(writer)?;
+

+
                Ok(n)
+
            }
+
        }
+

+
        impl wire::Decode for MessageWithExt {
+
            fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, wire::Error> {
+
                let msg = Message::decode(reader)?;
+
                let ext = String::decode(reader).unwrap_or_default();
+

+
                Ok(MessageWithExt { msg, ext })
+
            }
+
        }
+

+
        let rid = radicle::test::arbitrary::gen(1);
+
        let pk = radicle::test::arbitrary::gen(1);
+
        let sig: [u8; 64] = radicle::test::arbitrary::gen(1);
+

+
        // Message with extension.
+
        let mut stream = Vec::new();
+
        let ann = Message::announcement(
+
            pk,
+
            service::gossip::inventory(radicle::node::Timestamp::MAX, [rid]),
+
            radicle::crypto::Signature::from(sig),
+
        );
+
        let pong = Message::Pong {
+
            zeroes: ZeroBytes::new(42),
+
        };
+
        // Framed message with extension.
+
        frame::Frame::gossip(
+
            Link::Outbound,
+
            MessageWithExt {
+
                msg: ann.clone(),
+
                ext: String::from("extra"),
+
            },
+
        )
+
        .encode(&mut stream)
+
        .unwrap();
+
        // Pong message that comes after, without extension.
+
        frame::Frame::gossip(Link::Outbound, pong.clone())
+
            .encode(&mut stream)
+
            .unwrap();
+

+
        // First test deserializing using the message with extension type.
+
        {
+
            let mut de = deserializer::Deserializer::<1024, Frame<MessageWithExt>>::new(1024);
+
            de.input(&stream).unwrap();
+

+
            radicle::assert_matches!(
+
                de.deserialize_next().unwrap().unwrap().data,
+
                FrameData::Gossip(MessageWithExt {
+
                    msg,
+
                    ext,
+
                }) if msg == ann && ext == String::from("extra")
+
            );
+
            radicle::assert_matches!(
+
                de.deserialize_next().unwrap().unwrap().data,
+
                FrameData::Gossip(MessageWithExt {
+
                    msg,
+
                    ext,
+
                }) if msg == pong && ext.is_empty()
+
            );
+
            assert!(de.deserialize_next().unwrap().is_none());
+
            assert!(de.is_empty());
+
        }
+

+
        // Then test deserializing using the current message type without the extension.
+
        {
+
            let mut de = deserializer::Deserializer::<1024, Frame<Message>>::new(1024);
+
            de.input(&stream).unwrap();
+

+
            radicle::assert_matches!(
+
                de.deserialize_next().unwrap().unwrap().data,
+
                FrameData::Gossip(msg)
+
                if msg == ann
+
            );
+
            radicle::assert_matches!(
+
                de.deserialize_next().unwrap().unwrap().data,
+
                FrameData::Gossip(msg)
+
                if msg == pong
+
            );
+
            assert!(de.deserialize_next().unwrap().is_none());
+
            assert!(de.is_empty());
+
        }
+
    }
}