Skip to content

Commit

Permalink
Merge pull request #246 from ikatson/peer-ip-generic
Browse files Browse the repository at this point in the history
[Refactor] Generic peer IP
  • Loading branch information
ikatson authored Oct 1, 2024
2 parents abe4cf5 + 29758c6 commit e97e26f
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 74 deletions.
7 changes: 2 additions & 5 deletions crates/librqbit/src/peer_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ use librqbit_core::{
};
use parking_lot::RwLock;
use peer_binary_protocol::{
extended::{
handshake::{ExtendedHandshake, YourIP},
ExtendedMessage,
},
extended::{handshake::ExtendedHandshake, ExtendedMessage, PeerIP},
serialize_piece_preamble, Handshake, Message, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -251,7 +248,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
if supports_extended {
let mut my_extended = ExtendedHandshake::new();
my_extended.v = Some(ByteBuf(crate::client_name_and_version().as_bytes()));
my_extended.yourip = Some(YourIP(self.addr.ip()));
my_extended.yourip = Some(PeerIP(self.addr.ip()));
self.handler
.update_my_extended_handshake(&mut my_extended)?;
let my_extended = Message::Extended(ExtendedMessage::Handshake(my_extended));
Expand Down
78 changes: 9 additions & 69 deletions crates/peer_binary_protocol/src/extended/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ use std::{collections::HashMap, net::IpAddr};
use buffers::{ByteBuf, ByteBufT};
use bytes::Bytes;
use clone_to_owned::CloneToOwned;
use serde::{Deserialize, Deserializer, Serialize};
use serde::{Deserialize, Serialize};

use crate::{
EXTENDED_UT_METADATA_KEY, EXTENDED_UT_PEX_KEY, MY_EXTENDED_UT_METADATA, MY_EXTENDED_UT_PEX,
};

use super::PeerExtendedMessageIds;
use super::{PeerExtendedMessageIds, PeerIP4, PeerIP6, PeerIPAny};

#[derive(Deserialize, Serialize, Debug, Default)]
pub struct ExtendedHandshake<ByteBuf: ByteBufT> {
Expand All @@ -20,11 +20,11 @@ pub struct ExtendedHandshake<ByteBuf: ByteBufT> {
#[serde(skip_serializing_if = "Option::is_none")]
pub v: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub yourip: Option<YourIP>,
pub yourip: Option<PeerIPAny>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ipv6: Option<ByteBuf>,
pub ipv6: Option<PeerIP6>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ipv4: Option<ByteBuf>,
pub ipv4: Option<PeerIP4>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reqq: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -72,18 +72,10 @@ where

pub fn ip_addr(&self) -> Option<IpAddr> {
if let Some(ref b) = self.ipv4 {
let b = b.as_slice();
if b.len() == 4 {
let ip_bytes: &[u8; 4] = b[0..4].try_into().unwrap(); // Safe to unwrap as we check slice length
return Some(IpAddr::from(*ip_bytes));
}
return Some(b.0.into());
}
if let Some(ref b) = self.ipv6 {
let b = b.as_slice();
if b.len() == 16 {
let ip_bytes: &[u8; 16] = b[0..16].try_into().unwrap(); // Safe to unwrap as we check slice length
return Some(IpAddr::from(*ip_bytes));
}
return Some(b.0.into());
}
None
}
Expand All @@ -106,64 +98,12 @@ where
p: self.p,
v: self.v.clone_to_owned(within_buffer),
yourip: self.yourip,
ipv6: self.ipv6.clone_to_owned(within_buffer),
ipv4: self.ipv4.clone_to_owned(within_buffer),
ipv6: self.ipv6,
ipv4: self.ipv4,
reqq: self.reqq,
metadata_size: self.metadata_size,
complete_ago: self.complete_ago,
upload_only: self.upload_only,
}
}
}

#[derive(Debug, Clone, Copy)]
pub struct YourIP(pub IpAddr);

impl Serialize for YourIP {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self.0 {
IpAddr::V4(ipv4) => {
let buf = ipv4.octets();
serializer.serialize_bytes(&buf)
}
IpAddr::V6(ipv6) => {
let buf = ipv6.octets();
serializer.serialize_bytes(&buf)
}
}
}
}

impl<'de> Deserialize<'de> for YourIP {
fn deserialize<D>(de: D) -> Result<YourIP, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor {}
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = YourIP;

fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "expecting 4 bytes of ipv4 or 16 bytes of ipv6")
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if v.len() == 4 {
let ip_bytes: &[u8; 4] = v[0..4].try_into().unwrap(); // Safe to unwrap as we check slice length
return Ok(YourIP(IpAddr::from(*ip_bytes)));
} else if v.len() == 16 {
let ip_bytes: &[u8; 16] = v[0..16].try_into().unwrap(); // Safe to unwrap as we check slice length
return Ok(YourIP(IpAddr::from(*ip_bytes)));
}
Err(E::custom("expected 4 or 16 byte address"))
}
}
de.deserialize_bytes(Visitor {})
}
}
131 changes: 131 additions & 0 deletions crates/peer_binary_protocol/src/extended/ip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use std::{
marker::PhantomData,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};

use serde::{Deserialize, Deserializer, Serialize};

enum IpOctets {
V4([u8; 4]),
V6([u8; 16]),
}

impl IpOctets {
fn as_slice(&self) -> &[u8] {
match &self {
IpOctets::V4(s) => s,
IpOctets::V6(s) => s,
}
}
}

#[derive(Debug, Clone, Copy)]
pub struct PeerIP<T>(pub T);
pub type PeerIP4 = PeerIP<Ipv4Addr>;
pub type PeerIP6 = PeerIP<Ipv6Addr>;
pub type PeerIPAny = PeerIP<IpAddr>;

trait IpLike: Sized {
fn octets(&self) -> IpOctets;
fn try_from_slice(b: &[u8]) -> Option<Self>;
fn expecting() -> &'static str;
}

impl IpLike for Ipv4Addr {
fn octets(&self) -> IpOctets {
IpOctets::V4(self.octets())
}

fn try_from_slice(b: &[u8]) -> Option<Self> {
let arr: [u8; 4] = b.try_into().ok()?;
Some(arr.into())
}

fn expecting() -> &'static str {
"expecting 4 bytes of ipv4"
}
}

impl IpLike for Ipv6Addr {
fn octets(&self) -> IpOctets {
IpOctets::V6(self.octets())
}

fn try_from_slice(b: &[u8]) -> Option<Self> {
let arr: [u8; 16] = b.try_into().ok()?;
Some(arr.into())
}

fn expecting() -> &'static str {
"expecting 16 bytes of ipv6"
}
}

impl IpLike for IpAddr {
fn octets(&self) -> IpOctets {
match self {
IpAddr::V4(ipv4_addr) => IpOctets::V4(ipv4_addr.octets()),
IpAddr::V6(ipv6_addr) => IpOctets::V6(ipv6_addr.octets()),
}
}

fn try_from_slice(b: &[u8]) -> Option<Self> {
match b.len() {
4 => Ipv4Addr::try_from_slice(b).map(Into::into),
16 => Ipv6Addr::try_from_slice(b).map(Into::into),
_ => None,
}
}

fn expecting() -> &'static str {
"expecting 4 or 16 bytes of ipv4 or ipv6"
}
}

impl<T> Serialize for PeerIP<T>
where
T: IpLike,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_bytes(self.0.octets().as_slice())
}
}

impl<'de, T> Deserialize<'de> for PeerIP<T>
where
T: IpLike,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<T> {
p: PhantomData<T>,
}
impl<'de, T> serde::de::Visitor<'de> for Visitor<T>
where
T: IpLike,
{
type Value = PeerIP<T>;

fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str(T::expecting())
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
T::try_from_slice(v)
.map(PeerIP)
.ok_or_else(|| E::custom(T::expecting()))
}
}
deserializer.deserialize_bytes(Visitor {
p: Default::default(),
})
}
}
4 changes: 4 additions & 0 deletions crates/peer_binary_protocol/src/extended/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ use self::{handshake::ExtendedHandshake, ut_metadata::UtMetadata};
use super::MessageDeserializeError;

pub mod handshake;
mod ip;

pub use ip::{PeerIP, PeerIP4, PeerIP6, PeerIPAny};

pub mod ut_metadata;
pub mod ut_pex;

Expand Down

0 comments on commit e97e26f

Please sign in to comment.