Radish alpha
h
Radicle Heartwood Protocol & Stack
Radicle
Git (anonymous pull)
Log in to clone via SSH
node: Improve wire code
Alexis Sellier committed 3 years ago
commit 1dec499e3d33fb4d80d4cc6d359c458f0c42f1d1
parent 9d3030a0933aaf96125e1e5ea1683915ff867e2a
4 files changed +130 -57
modified node/src/protocol/message.rs
@@ -1,6 +1,6 @@
use std::net;

-
use byteorder::NetworkEndian;
+
use byteorder::{NetworkEndian, ReadBytesExt};
use serde::{Deserialize, Serialize};

use crate::crypto;
@@ -27,11 +27,77 @@ pub type NodeFeatures = [u8; 32];
// TODO: We should check the length and charset when deserializing.
pub struct Hostname(String);

+
/// Message type.
+
#[repr(u16)]
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+
pub enum MessageType {
+
    Hello = 0,
+
    Node = 2,
+
    GetInventory = 4,
+
    Inventory = 6,
+
    RefsUpdate = 8,
+
}
+

+
impl From<MessageType> for u16 {
+
    fn from(other: MessageType) -> Self {
+
        other as u16
+
    }
+
}
+

+
impl TryFrom<u16> for MessageType {
+
    type Error = u16;
+

+
    fn try_from(other: u16) -> Result<Self, Self::Error> {
+
        match other {
+
            0 => Ok(MessageType::Hello),
+
            2 => Ok(MessageType::Node),
+
            4 => Ok(MessageType::GetInventory),
+
            6 => Ok(MessageType::Inventory),
+
            8 => Ok(MessageType::RefsUpdate),
+
            _ => Err(other),
+
        }
+
    }
+
}
+

+
/// Address type.
+
#[repr(u8)]
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+
pub enum AddressType {
+
    Ipv4 = 1,
+
    Ipv6 = 2,
+
    Hostname = 3,
+
    Onion = 4,
+
}
+

+
impl From<AddressType> for u8 {
+
    fn from(other: AddressType) -> Self {
+
        other as u8
+
    }
+
}
+

+
impl TryFrom<u8> for AddressType {
+
    type Error = u8;
+

+
    fn try_from(other: u8) -> Result<Self, Self::Error> {
+
        match other {
+
            1 => Ok(AddressType::Ipv4),
+
            2 => Ok(AddressType::Ipv6),
+
            3 => Ok(AddressType::Hostname),
+
            4 => Ok(AddressType::Hostname),
+
            _ => Err(other),
+
        }
+
    }
+
}
+

/// Peer public protocol address.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum Address {
-
    Ip {
-
        ip: net::IpAddr,
+
    Ipv4 {
+
        ip: net::Ipv4Addr,
+
        port: u16,
+
    },
+
    Ipv6 {
+
        ip: net::Ipv6Addr,
        port: u16,
    },
    Hostname {
@@ -47,6 +113,17 @@ pub enum Address {
    },
}

+
impl From<net::SocketAddr> for Address {
+
    fn from(other: net::SocketAddr) -> Self {
+
        let port = other.port();
+

+
        match other.ip() {
+
            net::IpAddr::V4(ip) => Self::Ipv4 { ip, port },
+
            net::IpAddr::V6(ip) => Self::Ipv6 { ip, port },
+
        }
+
    }
+
}
+

impl wire::Encode for Envelope {
    fn encode<W: std::io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
        let mut n = 0;
@@ -72,17 +149,14 @@ impl wire::Encode for Address {
        let mut n = 0;

        match self {
-
            Self::Ip { ip, port } => {
-
                match ip {
-
                    net::IpAddr::V4(addr) => {
-
                        n += 1u8.encode(writer)?;
-
                        n += addr.octets().encode(writer)?;
-
                    }
-
                    net::IpAddr::V6(addr) => {
-
                        n += 2u8.encode(writer)?;
-
                        n += addr.octets().encode(writer)?;
-
                    }
-
                }
+
            Self::Ipv4 { ip, port } => {
+
                n += u8::from(AddressType::Ipv4).encode(writer)?;
+
                n += ip.octets().encode(writer)?;
+
                n += port.encode(writer)?;
+
            }
+
            Self::Ipv6 { ip, port } => {
+
                n += u8::from(AddressType::Ipv6).encode(writer)?;
+
                n += ip.octets().encode(writer)?;
                n += port.encode(writer)?;
            }
            Self::Hostname { .. } => todo!(),
@@ -94,26 +168,30 @@ impl wire::Encode for Address {

impl wire::Decode for Address {
    fn decode<R: std::io::Read + ?Sized>(reader: &mut R) -> Result<Self, wire::Error> {
-
        use byteorder::ReadBytesExt;
+
        let addrtype = reader.read_u8()?;

-
        match reader.read_u8()? {
-
            1 => {
+
        match AddressType::try_from(addrtype) {
+
            Ok(AddressType::Ipv4) => {
                let octets: [u8; 4] = wire::Decode::decode(reader)?;
-
                let ip = net::IpAddr::from(net::Ipv4Addr::from(octets));
+
                let ip = net::Ipv4Addr::from(octets);
                let port = u16::decode(reader)?;

-
                Ok(Self::Ip { ip, port })
+
                Ok(Self::Ipv4 { ip, port })
            }
-
            2 => {
+
            Ok(AddressType::Ipv6) => {
                let octets: [u8; 16] = wire::Decode::decode(reader)?;
-
                let ip = net::IpAddr::from(net::Ipv6Addr::from(octets));
+
                let ip = net::Ipv6Addr::from(octets);
                let port = u16::decode(reader)?;

-
                Ok(Self::Ip { ip, port })
+
                Ok(Self::Ipv6 { ip, port })
            }
-
            _ => {
+
            Ok(AddressType::Hostname) => {
                todo!();
            }
+
            Ok(AddressType::Onion) => {
+
                todo!();
+
            }
+
            Err(other) => Err(wire::Error::UnknownAddressType(other)),
        }
    }
}
@@ -223,12 +301,13 @@ impl Message {

    pub fn type_id(&self) -> u16 {
        match self {
-
            Self::Hello { .. } => 0,
-
            Self::Node { .. } => 2,
-
            Self::GetInventory { .. } => 4,
-
            Self::Inventory { .. } => 6,
-
            Self::RefsUpdate { .. } => 8,
+
            Self::Hello { .. } => MessageType::Hello,
+
            Self::Node { .. } => MessageType::Node,
+
            Self::GetInventory { .. } => MessageType::GetInventory,
+
            Self::Inventory { .. } => MessageType::Inventory,
+
            Self::RefsUpdate { .. } => MessageType::RefsUpdate,
        }
+
        .into()
    }
}

@@ -277,12 +356,10 @@ impl wire::Encode for Message {

impl wire::Decode for Message {
    fn decode<R: std::io::Read + ?Sized>(reader: &mut R) -> Result<Self, wire::Error> {
-
        use byteorder::ReadBytesExt;
-

        let type_id = reader.read_u16::<NetworkEndian>()?;

-
        match type_id {
-
            0 => {
+
        match MessageType::try_from(type_id) {
+
            Ok(MessageType::Hello) => {
                let id = NodeId::decode(reader)?;
                let timestamp = Timestamp::decode(reader)?;
                let version = u32::decode(reader)?;
@@ -297,15 +374,15 @@ impl wire::Decode for Message {
                    git,
                })
            }
-
            2 => {
+
            Ok(MessageType::Node) => {
                todo!();
            }
-
            4 => {
+
            Ok(MessageType::GetInventory) => {
                let ids = Vec::<Id>::decode(reader)?;

                Ok(Self::GetInventory { ids })
            }
-
            6 => {
+
            Ok(MessageType::Inventory) => {
                let node = NodeId::decode(reader)?;
                let inv = Vec::<Id>::decode(reader)?;
                let timestamp = Timestamp::decode(reader)?;
@@ -316,16 +393,14 @@ impl wire::Decode for Message {
                    timestamp,
                })
            }
-
            8 => {
+
            Ok(MessageType::RefsUpdate) => {
                let id = Id::decode(reader)?;
                let signer = crypto::PublicKey::decode(reader)?;
                let refs = SignedRefs::decode(reader)?;

                Ok(Self::RefsUpdate { id, signer, refs })
            }
-
            n => {
-
                todo!("Mesage type {} is not yet implemented", n);
-
            }
+
            Err(other) => Err(wire::Error::UnknownMessageType(other)),
        }
    }
}
modified node/src/protocol/wire.rs
@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::convert::TryFrom;
use std::ops::Deref;
use std::string::FromUtf8Error;
-
use std::{io, mem, net};
+
use std::{io, mem};

use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};

@@ -28,18 +28,26 @@ pub enum Error {
        url: String,
        error: git::url::parse::Error,
    },
+
    #[error("unknown address type `{0}`")]
+
    UnknownAddressType(u8),
+
    #[error("unknown message type `{0}`")]
+
    UnknownMessageType(u16),
}

impl Error {
+
    /// Whether we've reached the end of file. This will be true when we fail to decode
+
    /// a message because there's not enough data in the stream.
    pub fn is_eof(&self) -> bool {
        matches!(self, Self::Io(err) if err.kind() == io::ErrorKind::UnexpectedEof)
    }
}

+
/// Things that can be encoded as binary.
pub trait Encode {
    fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error>;
}

+
/// Things that can be decoded from binary.
pub trait Decode: Sized {
    fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, Error>;
}
@@ -96,6 +104,8 @@ impl Encode for u64 {
}

impl Encode for usize {
+
    /// We encode this type to a [`u32`], since there's no need to send larger messages
+
    /// over the network.
    fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
        assert!(
            *self <= u32::MAX as usize,
@@ -144,15 +154,6 @@ where
    }
}

-
impl Encode for net::IpAddr {
-
    fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
-
        match self {
-
            net::IpAddr::V4(addr) => addr.octets().encode(writer),
-
            net::IpAddr::V6(addr) => addr.octets().encode(writer),
-
        }
-
    }
-
}
-

impl Encode for &str {
    fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
        assert!(self.len() <= u8::MAX as usize);
modified node/src/test/arbitrary.rs
@@ -92,15 +92,15 @@ impl Arbitrary for Message {
impl Arbitrary for Address {
    fn arbitrary(g: &mut quickcheck::Gen) -> Self {
        if bool::arbitrary(g) {
-
            Address::Ip {
-
                ip: net::IpAddr::V4(net::Ipv4Addr::from(u32::arbitrary(g))),
+
            Address::Ipv4 {
+
                ip: net::Ipv4Addr::from(u32::arbitrary(g)),
                port: u16::arbitrary(g),
            }
        } else {
            let octets: [u8; 16] = ByteArray::<16>::arbitrary(g).into_inner();

-
            Address::Ip {
-
                ip: net::IpAddr::V6(net::Ipv6Addr::from(octets)),
+
            Address::Ipv6 {
+
                ip: net::Ipv6Addr::from(octets),
                port: u16::arbitrary(g),
            }
        }
modified node/src/test/peer.rs
@@ -147,10 +147,7 @@ where
            Message::hello(
                peer.id(),
                self.local_time().as_secs(),
-
                vec![Address::Ip {
-
                    ip: remote.ip(),
-
                    port: remote.port(),
-
                }],
+
                vec![Address::from(remote)],
                git,
            ),
        );