Radish alpha
h
Radicle Heartwood Protocol & Stack
Radicle
Git (anonymous pull)
Log in to clone via SSH
node: Use `BoundedVec` for inbox
cloudhead committed 2 years ago
commit 88a40297895d3c303a4df9e3e8ba5625ff96e9eb
parent f212dbb2eef289d961857480616af2b3af1d41ec
4 files changed +58 -30
modified radicle-node/src/bounded.rs
@@ -1,4 +1,7 @@
-
use std::{collections::BTreeSet, ops};
+
use std::{
+
    collections::BTreeSet,
+
    ops::{self, RangeBounds},
+
};

#[derive(thiserror::Error, Debug)]
pub enum Error {
@@ -155,6 +158,26 @@ impl<T, const N: usize> BoundedVec<T, N> {
    pub fn unbound(self) -> Vec<T> {
        self.v
    }
+

+
    /// Calls [`Vec::Drain`].
+
    pub fn drain<R: RangeBounds<usize>>(&mut self, range: R) -> std::vec::Drain<T> {
+
        self.v.drain(range)
+
    }
+
}
+

+
impl<T: Clone, const N: usize> BoundedVec<T, N> {
+
    /// Like [`Vec::extend_from_slice`] but returns an error if out of bounds.
+
    pub fn extend_from_slice(&mut self, slice: &[T]) -> Result<(), Error> {
+
        if self.len() + slice.len() > N {
+
            return Err(Error::InvalidSize {
+
                expected: N,
+
                actual: self.len() + slice.len(),
+
            });
+
        }
+
        self.v.extend_from_slice(slice);
+

+
        Ok(())
+
    }
}

impl<T, const N: usize> ops::Deref for BoundedVec<T, N> {
modified radicle-node/src/deserializer.rs
@@ -1,6 +1,8 @@
use std::io;
use std::marker::PhantomData;

+
use crate::bounded;
+
use crate::prelude::BoundedVec;
use crate::service::message::Message;
use crate::wire;

@@ -8,43 +10,46 @@ use crate::wire;
///
/// Used to for example turn a byte stream into network messages.
#[derive(Debug)]
-
pub struct Deserializer<D = Message> {
-
    unparsed: Vec<u8>,
+
pub struct Deserializer<const B: usize, D = Message> {
+
    unparsed: BoundedVec<u8, B>,
    item: PhantomData<D>,
}

-
impl<D: wire::Decode> Default for Deserializer<D> {
+
impl<const B: usize, D: wire::Decode> Default for Deserializer<B, D> {
    fn default() -> Self {
        Self::new(wire::Size::MAX as usize + 1)
    }
}

-
impl<D> From<Vec<u8>> for Deserializer<D> {
-
    fn from(unparsed: Vec<u8>) -> Self {
-
        Self {
+
impl<const B: usize, D> TryFrom<Vec<u8>> for Deserializer<B, D> {
+
    type Error = bounded::Error;
+

+
    fn try_from(unparsed: Vec<u8>) -> Result<Self, Self::Error> {
+
        BoundedVec::try_from(unparsed).map(|unparsed| Self {
            unparsed,
            item: PhantomData,
-
        }
+
        })
    }
}

-
impl<D: wire::Decode> Deserializer<D> {
+
impl<const B: usize, D: wire::Decode> Deserializer<B, D> {
    /// Create a new stream decoder.
    pub fn new(capacity: usize) -> Self {
        Self {
-
            unparsed: Vec::with_capacity(capacity),
+
            unparsed: BoundedVec::with_capacity(capacity)
+
                .expect("Deserializer::new: capacity exceeds maximum"),
            item: PhantomData,
        }
    }

    /// Input bytes into the decoder.
-
    pub fn input(&mut self, bytes: &[u8]) {
-
        self.unparsed.extend_from_slice(bytes);
+
    pub fn input(&mut self, bytes: &[u8]) -> Result<(), bounded::Error> {
+
        self.unparsed.extend_from_slice(bytes)
    }

    /// Decode and return the next message. Returns [`None`] if nothing was decoded.
    pub fn deserialize_next(&mut self) -> Result<Option<D>, wire::Error> {
-
        let mut reader = io::Cursor::new(self.unparsed.as_mut_slice());
+
        let mut reader = io::Cursor::new(self.unparsed.as_slice());

        match D::decode(&mut reader) {
            Ok(msg) => {
@@ -74,9 +79,9 @@ impl<D: wire::Decode> Deserializer<D> {
    }
}

-
impl<D: wire::Decode> io::Write for Deserializer<D> {
+
impl<const B: usize, D: wire::Decode> io::Write for Deserializer<B, D> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
-
        self.input(buf);
+
        self.input(buf).map_err(|_| io::ErrorKind::OutOfMemory)?;

        Ok(buf.len())
    }
@@ -86,7 +91,7 @@ impl<D: wire::Decode> io::Write for Deserializer<D> {
    }
}

-
impl<D: wire::Decode> Iterator for Deserializer<D> {
+
impl<const B: usize, D: wire::Decode> Iterator for Deserializer<B, D> {
    type Item = Result<D, wire::Error>;

    fn next(&mut self) -> Option<Self::Item> {
@@ -106,17 +111,17 @@ mod test {

    #[test]
    fn test_decode_next() {
-
        let mut decoder = Deserializer::<String>::new(8);
+
        let mut decoder = Deserializer::<1024, String>::new(8);

-
        decoder.input(&[3, b'b']);
+
        decoder.input(&[3, b'b']).unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(None));
        assert_eq!(decoder.unparsed.len(), 2);

-
        decoder.input(&[b'y']);
+
        decoder.input(&[b'y']).unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(None));
        assert_eq!(decoder.unparsed.len(), 3);

-
        decoder.input(&[b'e']);
+
        decoder.input(&[b'e']).unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(Some(s)) if s.as_str() == "bye");
        assert_eq!(decoder.unparsed.len(), 0);
        assert!(decoder.is_empty());
@@ -124,9 +129,9 @@ mod test {

    #[test]
    fn test_unparsed() {
-
        let mut decoder = Deserializer::<String>::new(8);
+
        let mut decoder = Deserializer::<1024, String>::new(8);

-
        decoder.input(&[3, b'b', b'y']);
+
        decoder.input(&[3, b'b', b'y']).unwrap();
        assert_eq!(decoder.unparsed().collect::<Vec<_>>(), vec![3, b'b', b'y']);
        assert!(decoder.is_empty());
    }
@@ -135,7 +140,7 @@ mod test {
    fn prop_decode_next(chunk_size: usize) {
        let mut bytes = vec![];
        let mut msgs = vec![];
-
        let mut decoder = Deserializer::<String>::new(8);
+
        let mut decoder = Deserializer::<1024, String>::new(8);

        let chunk_size = 1 + chunk_size % MSG_HELLO.len() + MSG_BYE.len();

@@ -143,7 +148,7 @@ mod test {
        bytes.extend_from_slice(MSG_BYE);

        for chunk in bytes.as_slice().chunks(chunk_size) {
-
            decoder.input(chunk);
+
            decoder.input(chunk).unwrap();

            while let Some(msg) = decoder.deserialize_next().unwrap() {
                msgs.push(msg);
modified radicle-node/src/wire/message.rs
@@ -564,7 +564,7 @@ mod tests {
    #[test]
    fn prop_message_decoder() {
        fn property(items: Vec<Message>) {
-
            let mut decoder = Deserializer::<Message>::new(8);
+
            let mut decoder = Deserializer::<1048576, Message>::new(8);

            for item in &items {
                item.encode(&mut decoder).unwrap();
modified radicle-node/src/wire/protocol.rs
@@ -201,7 +201,7 @@ enum Peer {
        addr: NetAddr<HostName>,
        link: Link,
        nid: NodeId,
-
        inbox: Deserializer<Frame>,
+
        inbox: Deserializer<MAX_INBOX_SIZE, Frame>,
        streams: Streams,
    },
    /// The peer was scheduled for disconnection. Once the transport is handed over
@@ -702,13 +702,13 @@ where
                    ..
                }) = self.peers.get_mut(&id)
                {
-
                    if inbox.len() + data.len() > MAX_INBOX_SIZE {
+
                    if inbox.input(&data).is_err() {
                        log::error!(target: "wire", "Maximum inbox size ({MAX_INBOX_SIZE}) reached for peer {nid}");
                        log::error!(target: "wire", "Unable to process messages fast enough for peer {nid}; disconnecting..");
                        self.disconnect(id, DisconnectReason::Session(session::Error::Misbehavior));
+

                        return;
                    }
-
                    inbox.input(&data);

                    loop {
                        match inbox.deserialize_next() {
@@ -1176,8 +1176,8 @@ mod test {
        // Encode gossip message using the varint-prefix format into the stream.
        varint::payload::encode(&gossip, &mut stream).unwrap();

-
        let mut de = deserializer::Deserializer::<Frame>::new(1024);
-
        de.input(&stream);
+
        let mut de = deserializer::Deserializer::<1024, Frame>::new(1024);
+
        de.input(&stream).unwrap();

        // The "pong" message decodes successfully, even though there is trailing data.
        assert_eq!(