From 59617de20d12761c3d7fe882e4efc706b53102fa Mon Sep 17 00:00:00 2001 From: Max Inden Date: Fri, 19 Aug 2022 12:11:16 +0900 Subject: [PATCH 01/39] transports/webrtc/: Test message framing sizes --- transports/webrtc/Cargo.toml | 6 +++++ transports/webrtc/build.rs | 23 +++++++++++++++++ transports/webrtc/src/lib.rs | 39 +++++++++++++++++++++++++++++ transports/webrtc/src/message.proto | 19 ++++++++++++++ 4 files changed, 87 insertions(+) create mode 100644 transports/webrtc/build.rs create mode 100644 transports/webrtc/src/message.proto diff --git a/transports/webrtc/Cargo.toml b/transports/webrtc/Cargo.toml index 4103834f240..9f412219bfd 100644 --- a/transports/webrtc/Cargo.toml +++ b/transports/webrtc/Cargo.toml @@ -23,6 +23,7 @@ libp2p-noise = { version = "0.38.0", path = "../../transports/noise" } log = "0.4" multihash = { version = "0.16", default-features = false, features = ["sha2"] } pin-project = "1.0.0" +prost = "0.10" rand = "0.8" serde = { version = "1.0", features = ["derive"] } stun = "0.4" @@ -35,6 +36,9 @@ webrtc-ice = "0.7.0" webrtc-sctp = "0.5.0" webrtc-util = { version = "0.5.4", default-features = false, features = ["conn", "vnet", "sync"] } +[build-dependencies] +prost-build = "0.10" + [dev-dependencies] anyhow = "1.0" env_logger = "0.9" @@ -43,3 +47,5 @@ libp2p-swarm = { version = "0.38.0", path = "../../swarm" } rand_core = "0.5" rcgen = "0.9" quickcheck = "1" +unsigned-varint = { version = "0.7", features = ["asynchronous_codec"] } +asynchronous-codec = { version = "0.6" } diff --git a/transports/webrtc/build.rs b/transports/webrtc/build.rs new file mode 100644 index 00000000000..3f582337a68 --- /dev/null +++ b/transports/webrtc/build.rs @@ -0,0 +1,23 @@ +// Copyright 2022 Protocol Labs. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +fn main() { + prost_build::compile_protos(&["src/message.proto"], &["src"]).unwrap(); +} diff --git a/transports/webrtc/src/lib.rs b/transports/webrtc/src/lib.rs index a0a8c768d2f..cc5e6a3225f 100644 --- a/transports/webrtc/src/lib.rs +++ b/transports/webrtc/src/lib.rs @@ -89,3 +89,42 @@ mod in_addr; mod sdp; mod udp_mux; mod webrtc_connection; + +mod message_proto { + include!(concat!(env!("OUT_DIR"), "/webrtc.pb.rs")); +} + +#[cfg(test)] +mod tests { + use super::*; + use asynchronous_codec::Encoder; + use bytes::BytesMut; + use prost::Message; + use unsigned_varint::codec::UviBytes; + + const MAX_MSG_LEN: usize = 16384; // 16kiB + const VARINT_LEN: usize = 2; + const PROTO_OVERHEAD: usize = 5; + + #[test] + fn proto_size() { + let message = [0; MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD]; + + let protobuf = message_proto::Message { + flag: Some(message_proto::message::Flag::CloseWrite.into()), + message: Some(message.to_vec()), + }; + + let mut encoded_msg = BytesMut::new(); + protobuf + .encode(&mut encoded_msg) + .expect("BytesMut to have sufficient capacity."); + assert_eq!(encoded_msg.len(), message.len() + PROTO_OVERHEAD); + + let mut uvi = UviBytes::default(); + let mut dst = BytesMut::new(); + uvi.encode(encoded_msg.clone().freeze(), &mut dst).unwrap(); + assert_eq!(dst.len(), MAX_MSG_LEN); + assert_eq!(dst.len() - encoded_msg.len(), VARINT_LEN); + } +} diff --git a/transports/webrtc/src/message.proto b/transports/webrtc/src/message.proto new file mode 100644 index 00000000000..988fa762de3 --- /dev/null +++ b/transports/webrtc/src/message.proto @@ -0,0 +1,19 @@ +syntax = "proto2"; + +package webrtc.pb; + +message Message { + enum Flag { + // The local endpoint will no longer send messages. + CLOSE_WRITE = 0; + // The local endpoint will no longer read messages. + CLOSE_READ = 1; + // The local endpoint abruptly terminates the stream. The remote endpoint + // may discard any in-flight data. + RESET = 2; + } + + optional Flag flag=1; + + optional bytes message = 2; +} From 7dc3ce1c916a5be2741930424095cedf838a1cd6 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Mon, 22 Aug 2022 17:34:12 +0900 Subject: [PATCH 02/39] transports/webrtc/: Implement message framing --- misc/prost-codec/src/lib.rs | 9 + transports/webrtc/Cargo.toml | 4 + .../src/connection/poll_data_channel.rs | 245 +++++++++++++++--- transports/webrtc/src/message.proto | 1 + 4 files changed, 224 insertions(+), 35 deletions(-) diff --git a/misc/prost-codec/src/lib.rs b/misc/prost-codec/src/lib.rs index 32b8c9b9577..8c797eeec15 100644 --- a/misc/prost-codec/src/lib.rs +++ b/misc/prost-codec/src/lib.rs @@ -79,3 +79,12 @@ pub enum Error { std::io::Error, ), } + +impl From for std::io::Error { + fn from(e: Error) -> Self { + match e { + Error::Decode(e) => e.into(), + Error::Io(e) => e, + } + } +} diff --git a/transports/webrtc/Cargo.toml b/transports/webrtc/Cargo.toml index 9f412219bfd..cfab229cc44 100644 --- a/transports/webrtc/Cargo.toml +++ b/transports/webrtc/Cargo.toml @@ -10,6 +10,7 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] +asynchronous-codec = "0.6" async-trait = "0.1" bytes = "1" fnv = "1.0" @@ -24,12 +25,15 @@ log = "0.4" multihash = { version = "0.16", default-features = false, features = ["sha2"] } pin-project = "1.0.0" prost = "0.10" +prost-codec = { version = "0.1", path = "../../misc/prost-codec" } rand = "0.8" serde = { version = "1.0", features = ["derive"] } stun = "0.4" thiserror = "1" tinytemplate = "1.2" tokio-crate = { package = "tokio", version = "1.18", features = ["net"]} +# TODO: Needed? +tokio-util = { version = "0.7", features = ["compat"] } webrtc = { version = "0.4.0", git = "https://github.com/webrtc-rs/webrtc.git" } webrtc-data = { version = "0.4.0", git = "https://github.com/webrtc-rs/data.git" } webrtc-ice = "0.7.0" diff --git a/transports/webrtc/src/connection/poll_data_channel.rs b/transports/webrtc/src/connection/poll_data_channel.rs index 08d44a83657..cb0a719c049 100644 --- a/transports/webrtc/src/connection/poll_data_channel.rs +++ b/transports/webrtc/src/connection/poll_data_channel.rs @@ -18,7 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use asynchronous_codec::Framed; +use bytes::Bytes; use futures::prelude::*; +use futures::ready; +use tokio_util::compat::Compat; +use tokio_util::compat::TokioAsyncReadCompatExt; use webrtc_data::data_channel::DataChannel; use webrtc_data::data_channel::PollDataChannel as RTCPollDataChannel; @@ -28,65 +33,133 @@ use std::sync::Arc; use std::task::{Context, Poll}; /// A wrapper around [`RTCPollDataChannel`] implementing futures [`AsyncRead`] / [`AsyncWrite`]. -#[derive(Debug)] -pub struct PollDataChannel(RTCPollDataChannel); +// TODO +// #[derive(Debug)] +pub struct PollDataChannel { + io: Framed, prost_codec::Codec>, + state: State, +} + +enum State { + Open { read_buffer: Bytes }, + WriteClosed { read_buffer: Bytes }, + ReadClosed { read_buffer: Bytes }, + ReadWriteClosed { read_buffer: Bytes }, + Reset, + Poisoned, +} + +impl State { + fn handle_flag(&mut self, flag: crate::message_proto::message::Flag) { + match (std::mem::replace(self, State::Poisoned), flag) { + ( + State::Open { read_buffer } | State::WriteClosed { read_buffer }, + crate::message_proto::message::Flag::CloseRead, + ) => { + *self = State::WriteClosed { read_buffer }; + } + ( + State::ReadClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, + crate::message_proto::message::Flag::CloseRead, + ) => { + *self = State::ReadWriteClosed { read_buffer }; + } + ( + State::Open { read_buffer } | State::ReadClosed { read_buffer }, + crate::message_proto::message::Flag::CloseWrite, + ) => { + *self = State::ReadClosed { read_buffer }; + } + ( + State::WriteClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, + crate::message_proto::message::Flag::CloseWrite, + ) => { + *self = State::ReadWriteClosed { read_buffer }; + } + // TODO: Or do we want to return an error? + (State::Reset, _) => *self = State::Reset, + (_, crate::message_proto::message::Flag::Reset) => *self = State::Reset, + (State::Poisoned, _) => unreachable!(), + } + } + + fn read_buffer_mut(&mut self) -> Option<&mut Bytes> { + match self { + State::Open { read_buffer } => Some(read_buffer), + State::WriteClosed { read_buffer } => Some(read_buffer), + State::ReadClosed { read_buffer } => Some(read_buffer), + State::ReadWriteClosed { read_buffer } => Some(read_buffer), + State::Reset => None, + State::Poisoned => todo!(), + } + } +} impl PollDataChannel { /// Constructs a new `PollDataChannel`. pub fn new(data_channel: Arc) -> Self { - Self(RTCPollDataChannel::new(data_channel)) + Self { + io: Framed::new( + RTCPollDataChannel::new(data_channel).compat(), + // TODO: Fix MAX + prost_codec::Codec::new(usize::MAX), + ), + state: State::Open { + read_buffer: Default::default(), + }, + } } /// Get back the inner data_channel. pub fn into_inner(self) -> RTCPollDataChannel { - self.0 + self.io.into_inner().into_inner() } /// Obtain a clone of the inner data_channel. pub fn clone_inner(&self) -> RTCPollDataChannel { - self.0.clone() + self.io.get_ref().clone() } /// MessagesSent returns the number of messages sent pub fn messages_sent(&self) -> usize { - self.0.messages_sent() + self.io.get_ref().messages_sent() } /// MessagesReceived returns the number of messages received pub fn messages_received(&self) -> usize { - self.0.messages_received() + self.io.get_ref().messages_received() } /// BytesSent returns the number of bytes sent pub fn bytes_sent(&self) -> usize { - self.0.bytes_sent() + self.io.get_ref().bytes_sent() } /// BytesReceived returns the number of bytes received pub fn bytes_received(&self) -> usize { - self.0.bytes_received() + self.io.get_ref().bytes_received() } /// StreamIdentifier returns the Stream identifier associated to the stream. pub fn stream_identifier(&self) -> u16 { - self.0.stream_identifier() + self.io.get_ref().stream_identifier() } /// BufferedAmount returns the number of bytes of data currently queued to be /// sent over this stream. pub fn buffered_amount(&self) -> usize { - self.0.buffered_amount() + self.io.get_ref().buffered_amount() } /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing /// data that is considered "low." Defaults to 0. pub fn buffered_amount_low_threshold(&self) -> usize { - self.0.buffered_amount_low_threshold() + self.io.get_ref().buffered_amount_low_threshold() } /// Set the capacity of the temporary read buffer (default: 8192). pub fn set_read_buf_capacity(&mut self, capacity: usize) { - self.0.set_read_buf_capacity(capacity) + self.io.get_mut().set_read_buf_capacity(capacity) } } @@ -96,13 +169,84 @@ impl AsyncRead for PollDataChannel { cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let mut read_buf = tokio_crate::io::ReadBuf::new(buf); - futures::ready!(tokio_crate::io::AsyncRead::poll_read( - Pin::new(&mut self.0), - cx, - &mut read_buf - ))?; - Poll::Ready(Ok(read_buf.filled().len())) + loop { + if let Some(read_buffer) = self.state.read_buffer_mut() { + if !read_buffer.is_empty() { + let n = std::cmp::min(read_buffer.len(), buf.len()); + let data = read_buffer.split_to(n); + buf[0..n].copy_from_slice(&data[..]); + + return Poll::Ready(Ok(n)); + } + } + + match &mut *self { + PollDataChannel { + state: + State::Open { + ref mut read_buffer, + }, + io, + } + | PollDataChannel { + state: + State::WriteClosed { + ref mut read_buffer, + }, + io, + } => { + match ready!(io.poll_next_unpin(cx)) + .transpose() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + { + Some(crate::message_proto::Message { flag, message }) => { + assert!(read_buffer.is_empty()); + if let Some(message) = message { + *read_buffer = message.into(); + } + + if let Some(flag) = flag + .map(|f| { + crate::message_proto::message::Flag::from_i32(f) + .ok_or(io::Error::new(io::ErrorKind::InvalidData, "")) + }) + .transpose()? + { + self.state.handle_flag(flag) + } + + continue; + } + None => { + self.state + .handle_flag(crate::message_proto::message::Flag::CloseWrite); + return Poll::Ready(Ok(0)); + } + } + } + PollDataChannel { + state: State::ReadClosed { .. }, + .. + } + | PollDataChannel { + state: State::ReadWriteClosed { .. }, + .. + } => return Poll::Ready(Ok(0)), + PollDataChannel { + state: State::Reset, + .. + } => { + // TODO: Is `""` valid? + return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, ""))); + } + PollDataChannel { + state: State::Poisoned, + .. + } => { + todo!() + } + } + } } } @@ -112,28 +256,59 @@ impl AsyncWrite for PollDataChannel { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - tokio_crate::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf) + match self.state { + State::WriteClosed { .. } | State::ReadWriteClosed { .. } => return Poll::Ready(Ok(0)), + State::Reset => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, ""))); + } + State::Open { .. } => {} + State::ReadClosed { .. } => {} + State::Poisoned => todo!(), + } + + ready!(self.io.poll_ready_unpin(cx))?; + + Pin::new(&mut self.io).start_send(crate::message_proto::Message { + flag: None, + message: Some(buf.into()), + })?; + + Poll::Ready(Ok(buf.len())) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - tokio_crate::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx) + // TODO: Double check that we don't have to depend on self.state here. + self.io.poll_flush_unpin(cx).map_err(Into::into) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - tokio_crate::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx) - } + match &self.state { + State::WriteClosed { .. } | State::ReadWriteClosed { .. } => {} + State::Open { .. } | State::ReadClosed { .. } => { + ready!(self.io.poll_ready_unpin(cx))?; + Pin::new(&mut self.io).start_send(crate::message_proto::Message { + flag: Some(crate::message_proto::message::Flag::CloseWrite.into()), + message: None, + })?; - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - tokio_crate::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs) - } -} + match std::mem::replace(&mut self.state, State::Poisoned) { + State::Open { read_buffer } => self.state = State::WriteClosed { read_buffer }, + State::ReadClosed { read_buffer } => { + self.state = State::ReadWriteClosed { read_buffer } + } + State::WriteClosed { .. } + | State::ReadWriteClosed { .. } + | State::Reset + | State::Poisoned => unreachable!(), + } + } + State::Reset => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, ""))); + } + State::Poisoned => todo!(), + } -impl Clone for PollDataChannel { - fn clone(&self) -> PollDataChannel { - PollDataChannel(self.clone_inner()) + // TODO: Is flush the correct thing here? We don't want the underlying layer to close both write and read. + self.io.poll_flush_unpin(cx).map_err(Into::into) } } diff --git a/transports/webrtc/src/message.proto b/transports/webrtc/src/message.proto index 988fa762de3..81890c14f17 100644 --- a/transports/webrtc/src/message.proto +++ b/transports/webrtc/src/message.proto @@ -4,6 +4,7 @@ package webrtc.pb; message Message { enum Flag { + // TODO: Change to sender // The local endpoint will no longer send messages. CLOSE_WRITE = 0; // The local endpoint will no longer read messages. From 503e32fb99098c3dc1b7e7f84e82749c8a4d87f2 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Mon, 22 Aug 2022 17:56:03 +0900 Subject: [PATCH 03/39] transports/webrtc: Update protobuf --- transports/webrtc/src/message.proto | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/transports/webrtc/src/message.proto b/transports/webrtc/src/message.proto index 81890c14f17..27de2b7c45f 100644 --- a/transports/webrtc/src/message.proto +++ b/transports/webrtc/src/message.proto @@ -4,10 +4,9 @@ package webrtc.pb; message Message { enum Flag { - // TODO: Change to sender - // The local endpoint will no longer send messages. + // The sender will no longer send messages. CLOSE_WRITE = 0; - // The local endpoint will no longer read messages. + // The sender will no longer read messages. CLOSE_READ = 1; // The local endpoint abruptly terminates the stream. The remote endpoint // may discard any in-flight data. From 55da918052a7c6ee864972c7166b665e29af5d90 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Mon, 5 Sep 2022 12:26:01 +0900 Subject: [PATCH 04/39] transports/webrtc/: Change semantic of RESET With https://github.com/mxinden/specs/pull/1/commits/865f4f2dea8872c8de301a16b59000ac4540f18d the RESET no longer resets both write and read part of a stream, but only the former. --- .../src/connection/poll_data_channel.rs | 84 +++++++++++++------ transports/webrtc/src/message.proto | 13 +-- 2 files changed, 67 insertions(+), 30 deletions(-) diff --git a/transports/webrtc/src/connection/poll_data_channel.rs b/transports/webrtc/src/connection/poll_data_channel.rs index 19f6aec2061..c2df00a9ac8 100644 --- a/transports/webrtc/src/connection/poll_data_channel.rs +++ b/transports/webrtc/src/connection/poll_data_channel.rs @@ -45,40 +45,72 @@ enum State { WriteClosed { read_buffer: Bytes }, ReadClosed { read_buffer: Bytes }, ReadWriteClosed { read_buffer: Bytes }, - Reset, + ReadReset, + ReadResetWriteClosed, Poisoned, } impl State { fn handle_flag(&mut self, flag: crate::message_proto::message::Flag) { match (std::mem::replace(self, State::Poisoned), flag) { + // StopSending ( State::Open { read_buffer } | State::WriteClosed { read_buffer }, - crate::message_proto::message::Flag::CloseRead, + crate::message_proto::message::Flag::StopSending, ) => { *self = State::WriteClosed { read_buffer }; } + ( State::ReadClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, - crate::message_proto::message::Flag::CloseRead, + crate::message_proto::message::Flag::StopSending, ) => { *self = State::ReadWriteClosed { read_buffer }; } + + ( + State::ReadReset | State::ReadResetWriteClosed, + crate::message_proto::message::Flag::StopSending, + ) => { + *self = State::ReadResetWriteClosed; + } + + // Fin ( State::Open { read_buffer } | State::ReadClosed { read_buffer }, - crate::message_proto::message::Flag::CloseWrite, + crate::message_proto::message::Flag::Fin, ) => { *self = State::ReadClosed { read_buffer }; } + ( State::WriteClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, - crate::message_proto::message::Flag::CloseWrite, + crate::message_proto::message::Flag::Fin, ) => { *self = State::ReadWriteClosed { read_buffer }; } - // TODO: Or do we want to return an error? - (State::Reset, _) => *self = State::Reset, - (_, crate::message_proto::message::Flag::Reset) => *self = State::Reset, + + (State::ReadReset, crate::message_proto::message::Flag::Fin) => { + *self = State::ReadReset + } + + (State::ReadResetWriteClosed, crate::message_proto::message::Flag::Fin) => { + *self = State::ReadResetWriteClosed + } + + // Reset + ( + State::ReadClosed { .. } | State::ReadReset | State::Open { .. }, + crate::message_proto::message::Flag::Reset, + ) => *self = State::ReadReset, + + ( + State::ReadWriteClosed { .. } + | State::WriteClosed { .. } + | State::ReadResetWriteClosed, + crate::message_proto::message::Flag::Reset, + ) => *self = State::ReadResetWriteClosed, + (State::Poisoned, _) => unreachable!(), } } @@ -89,7 +121,8 @@ impl State { State::WriteClosed { read_buffer } => Some(read_buffer), State::ReadClosed { read_buffer } => Some(read_buffer), State::ReadWriteClosed { read_buffer } => Some(read_buffer), - State::Reset => None, + State::ReadReset => None, + State::ReadResetWriteClosed => None, State::Poisoned => todo!(), } } @@ -219,7 +252,7 @@ impl AsyncRead for PollDataChannel { } None => { self.state - .handle_flag(crate::message_proto::message::Flag::CloseWrite); + .handle_flag(crate::message_proto::message::Flag::Fin); return Poll::Ready(Ok(0)); } } @@ -233,7 +266,7 @@ impl AsyncRead for PollDataChannel { .. } => return Poll::Ready(Ok(0)), PollDataChannel { - state: State::Reset, + state: State::ReadReset | State::ReadResetWriteClosed, .. } => { // TODO: Is `""` valid? @@ -257,11 +290,10 @@ impl AsyncWrite for PollDataChannel { buf: &[u8], ) -> Poll> { match self.state { - State::WriteClosed { .. } | State::ReadWriteClosed { .. } => return Poll::Ready(Ok(0)), - State::Reset => { - return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, ""))); - } - State::Open { .. } => {} + State::WriteClosed { .. } + | State::ReadWriteClosed { .. } + | State::ReadResetWriteClosed => return Poll::Ready(Ok(0)), + State::Open { .. } | State::ReadReset => {} State::ReadClosed { .. } => {} State::Poisoned => todo!(), } @@ -283,11 +315,14 @@ impl AsyncWrite for PollDataChannel { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &self.state { - State::WriteClosed { .. } | State::ReadWriteClosed { .. } => {} - State::Open { .. } | State::ReadClosed { .. } => { + State::WriteClosed { .. } + | State::ReadWriteClosed { .. } + | State::ReadResetWriteClosed { .. } => {} + + State::Open { .. } | State::ReadClosed { .. } | State::ReadReset => { ready!(self.io.poll_ready_unpin(cx))?; Pin::new(&mut self.io).start_send(crate::message_proto::Message { - flag: Some(crate::message_proto::message::Flag::CloseWrite.into()), + flag: Some(crate::message_proto::message::Flag::Fin.into()), message: None, })?; @@ -296,15 +331,16 @@ impl AsyncWrite for PollDataChannel { State::ReadClosed { read_buffer } => { self.state = State::ReadWriteClosed { read_buffer } } + State::ReadReset => self.state = State::ReadResetWriteClosed, State::WriteClosed { .. } | State::ReadWriteClosed { .. } - | State::Reset - | State::Poisoned => unreachable!(), + | State::ReadResetWriteClosed + | State::Poisoned => { + unreachable!() + } } } - State::Reset => { - return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, ""))); - } + State::Poisoned => todo!(), } diff --git a/transports/webrtc/src/message.proto b/transports/webrtc/src/message.proto index 27de2b7c45f..eab3ceb720b 100644 --- a/transports/webrtc/src/message.proto +++ b/transports/webrtc/src/message.proto @@ -4,12 +4,13 @@ package webrtc.pb; message Message { enum Flag { - // The sender will no longer send messages. - CLOSE_WRITE = 0; - // The sender will no longer read messages. - CLOSE_READ = 1; - // The local endpoint abruptly terminates the stream. The remote endpoint - // may discard any in-flight data. + // The sender will no longer send messages on the stream. + FIN = 0; + // The sender will no longer read messages on the stream. Incoming data is + // being discarded on receipt. + STOP_SENDING = 1; + // The sender abruptly terminates the sending part of the stream. The + // receiver can discard any data that it already received on that stream. RESET = 2; } From 11c016f27181fc0cd172ff97a10996c418db6196 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Sat, 10 Sep 2022 20:19:05 +0900 Subject: [PATCH 05/39] transports/webrtc/: Import message_proto types --- .../src/connection/poll_data_channel.rs | 53 ++++++++----------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/transports/webrtc/src/connection/poll_data_channel.rs b/transports/webrtc/src/connection/poll_data_channel.rs index c2df00a9ac8..d5694ff7cba 100644 --- a/transports/webrtc/src/connection/poll_data_channel.rs +++ b/transports/webrtc/src/connection/poll_data_channel.rs @@ -32,11 +32,14 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::message_proto::message::Flag; +use crate::message_proto::Message; + /// A wrapper around [`RTCPollDataChannel`] implementing futures [`AsyncRead`] / [`AsyncWrite`]. // TODO // #[derive(Debug)] pub struct PollDataChannel { - io: Framed, prost_codec::Codec>, + io: Framed, prost_codec::Codec>, state: State, } @@ -51,64 +54,53 @@ enum State { } impl State { - fn handle_flag(&mut self, flag: crate::message_proto::message::Flag) { + fn handle_flag(&mut self, flag: Flag) { match (std::mem::replace(self, State::Poisoned), flag) { // StopSending ( State::Open { read_buffer } | State::WriteClosed { read_buffer }, - crate::message_proto::message::Flag::StopSending, + Flag::StopSending, ) => { *self = State::WriteClosed { read_buffer }; } ( State::ReadClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, - crate::message_proto::message::Flag::StopSending, + Flag::StopSending, ) => { *self = State::ReadWriteClosed { read_buffer }; } - ( - State::ReadReset | State::ReadResetWriteClosed, - crate::message_proto::message::Flag::StopSending, - ) => { + (State::ReadReset | State::ReadResetWriteClosed, Flag::StopSending) => { *self = State::ReadResetWriteClosed; } // Fin - ( - State::Open { read_buffer } | State::ReadClosed { read_buffer }, - crate::message_proto::message::Flag::Fin, - ) => { + (State::Open { read_buffer } | State::ReadClosed { read_buffer }, Flag::Fin) => { *self = State::ReadClosed { read_buffer }; } ( State::WriteClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, - crate::message_proto::message::Flag::Fin, + Flag::Fin, ) => { *self = State::ReadWriteClosed { read_buffer }; } - (State::ReadReset, crate::message_proto::message::Flag::Fin) => { - *self = State::ReadReset - } + (State::ReadReset, Flag::Fin) => *self = State::ReadReset, - (State::ReadResetWriteClosed, crate::message_proto::message::Flag::Fin) => { - *self = State::ReadResetWriteClosed - } + (State::ReadResetWriteClosed, Flag::Fin) => *self = State::ReadResetWriteClosed, // Reset - ( - State::ReadClosed { .. } | State::ReadReset | State::Open { .. }, - crate::message_proto::message::Flag::Reset, - ) => *self = State::ReadReset, + (State::ReadClosed { .. } | State::ReadReset | State::Open { .. }, Flag::Reset) => { + *self = State::ReadReset + } ( State::ReadWriteClosed { .. } | State::WriteClosed { .. } | State::ReadResetWriteClosed, - crate::message_proto::message::Flag::Reset, + Flag::Reset, ) => *self = State::ReadResetWriteClosed, (State::Poisoned, _) => unreachable!(), @@ -232,7 +224,7 @@ impl AsyncRead for PollDataChannel { .transpose() .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? { - Some(crate::message_proto::Message { flag, message }) => { + Some(Message { flag, message }) => { assert!(read_buffer.is_empty()); if let Some(message) = message { *read_buffer = message.into(); @@ -240,7 +232,7 @@ impl AsyncRead for PollDataChannel { if let Some(flag) = flag .map(|f| { - crate::message_proto::message::Flag::from_i32(f) + Flag::from_i32(f) .ok_or(io::Error::new(io::ErrorKind::InvalidData, "")) }) .transpose()? @@ -251,8 +243,7 @@ impl AsyncRead for PollDataChannel { continue; } None => { - self.state - .handle_flag(crate::message_proto::message::Flag::Fin); + self.state.handle_flag(Flag::Fin); return Poll::Ready(Ok(0)); } } @@ -300,7 +291,7 @@ impl AsyncWrite for PollDataChannel { ready!(self.io.poll_ready_unpin(cx))?; - Pin::new(&mut self.io).start_send(crate::message_proto::Message { + Pin::new(&mut self.io).start_send(Message { flag: None, message: Some(buf.into()), })?; @@ -321,8 +312,8 @@ impl AsyncWrite for PollDataChannel { State::Open { .. } | State::ReadClosed { .. } | State::ReadReset => { ready!(self.io.poll_ready_unpin(cx))?; - Pin::new(&mut self.io).start_send(crate::message_proto::Message { - flag: Some(crate::message_proto::message::Flag::Fin.into()), + Pin::new(&mut self.io).start_send(Message { + flag: Some(Flag::Fin.into()), message: None, })?; From 1a6e4bd81efe7aed4dd861520cc7c92756288ba2 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Sat, 10 Sep 2022 20:25:56 +0900 Subject: [PATCH 06/39] transports/webrtc/: Refactor AsyncRead match arm --- .../src/connection/poll_data_channel.rs | 96 ++++++++----------- 1 file changed, 38 insertions(+), 58 deletions(-) diff --git a/transports/webrtc/src/connection/poll_data_channel.rs b/transports/webrtc/src/connection/poll_data_channel.rs index d5694ff7cba..cb34e0e99f3 100644 --- a/transports/webrtc/src/connection/poll_data_channel.rs +++ b/transports/webrtc/src/connection/poll_data_channel.rs @@ -205,69 +205,49 @@ impl AsyncRead for PollDataChannel { } } - match &mut *self { - PollDataChannel { - state: - State::Open { - ref mut read_buffer, - }, - io, - } - | PollDataChannel { - state: - State::WriteClosed { - ref mut read_buffer, - }, - io, - } => { - match ready!(io.poll_next_unpin(cx)) - .transpose() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? - { - Some(Message { flag, message }) => { - assert!(read_buffer.is_empty()); - if let Some(message) = message { - *read_buffer = message.into(); - } - - if let Some(flag) = flag - .map(|f| { - Flag::from_i32(f) - .ok_or(io::Error::new(io::ErrorKind::InvalidData, "")) - }) - .transpose()? - { - self.state.handle_flag(flag) - } - - continue; - } - None => { - self.state.handle_flag(Flag::Fin); - return Poll::Ready(Ok(0)); - } - } + let PollDataChannel { state, io } = &mut *self; + + let read_buffer = match state { + State::Open { + ref mut read_buffer, } - PollDataChannel { - state: State::ReadClosed { .. }, - .. + | State::WriteClosed { + ref mut read_buffer, + } => read_buffer, + State::ReadClosed { .. } | State::ReadWriteClosed { .. } => { + return Poll::Ready(Ok(0)) } - | PollDataChannel { - state: State::ReadWriteClosed { .. }, - .. - } => return Poll::Ready(Ok(0)), - PollDataChannel { - state: State::ReadReset | State::ReadResetWriteClosed, - .. - } => { + State::ReadReset | State::ReadResetWriteClosed => { // TODO: Is `""` valid? return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, ""))); } - PollDataChannel { - state: State::Poisoned, - .. - } => { - todo!() + State::Poisoned => todo!(), + }; + + match ready!(io.poll_next_unpin(cx)) + .transpose() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + { + Some(Message { flag, message }) => { + assert!(read_buffer.is_empty()); + if let Some(message) = message { + *read_buffer = message.into(); + } + + if let Some(flag) = flag + .map(|f| { + Flag::from_i32(f).ok_or(io::Error::new(io::ErrorKind::InvalidData, "")) + }) + .transpose()? + { + self.state.handle_flag(flag) + } + + continue; + } + None => { + self.state.handle_flag(Flag::Fin); + return Poll::Ready(Ok(0)); } } } From d46a171ae11d6984480289deded633bde2bb57e2 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Sat, 10 Sep 2022 22:38:45 +0900 Subject: [PATCH 07/39] transports/webrtc/: Handle flags when read side closed --- .../src/connection/poll_data_channel.rs | 70 ++++++++++++++----- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/transports/webrtc/src/connection/poll_data_channel.rs b/transports/webrtc/src/connection/poll_data_channel.rs index cb34e0e99f3..75153d50091 100644 --- a/transports/webrtc/src/connection/poll_data_channel.rs +++ b/transports/webrtc/src/connection/poll_data_channel.rs @@ -186,6 +186,27 @@ impl PollDataChannel { pub fn set_read_buf_capacity(&mut self, capacity: usize) { self.io.get_mut().set_read_buf_capacity(capacity) } + + fn io_poll_next( + io: &mut Framed, prost_codec::Codec>, + cx: &mut Context<'_>, + ) -> Poll, Option>)>>> { + match ready!(io.poll_next_unpin(cx)) + .transpose() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + { + Some(Message { flag, message }) => { + let flag = flag + .map(|f| { + Flag::from_i32(f).ok_or(io::Error::new(io::ErrorKind::InvalidData, "")) + }) + .transpose()?; + + Poll::Ready(Ok(Some((flag, message)))) + } + None => Poll::Ready(Ok(None)), + } + } } impl AsyncRead for PollDataChannel { @@ -214,8 +235,10 @@ impl AsyncRead for PollDataChannel { | State::WriteClosed { ref mut read_buffer, } => read_buffer, - State::ReadClosed { .. } | State::ReadWriteClosed { .. } => { - return Poll::Ready(Ok(0)) + State::ReadClosed { read_buffer, .. } + | State::ReadWriteClosed { read_buffer, .. } => { + assert!(read_buffer.is_empty()); + return Poll::Ready(Ok(0)); } State::ReadReset | State::ReadResetWriteClosed => { // TODO: Is `""` valid? @@ -224,26 +247,16 @@ impl AsyncRead for PollDataChannel { State::Poisoned => todo!(), }; - match ready!(io.poll_next_unpin(cx)) - .transpose() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? - { - Some(Message { flag, message }) => { + match ready!(Self::io_poll_next(io, cx))? { + Some((flag, message)) => { assert!(read_buffer.is_empty()); if let Some(message) = message { *read_buffer = message.into(); } - if let Some(flag) = flag - .map(|f| { - Flag::from_i32(f).ok_or(io::Error::new(io::ErrorKind::InvalidData, "")) - }) - .transpose()? - { + if let Some(flag) = flag { self.state.handle_flag(flag) - } - - continue; + }; } None => { self.state.handle_flag(Flag::Fin); @@ -260,12 +273,33 @@ impl AsyncWrite for PollDataChannel { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + // Handle flags iff read side closed. + loop { + match self.state { + State::ReadClosed { .. } | State::ReadReset => + // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we will poll the + // underlying I/O resource once more. Is that allowed? How about introducing a state IoReadClosed? + { + match Self::io_poll_next(&mut self.io, cx)? { + Poll::Ready(Some((Some(flag), message))) => { + // Read side is closed. Discard any incoming messages. + drop(message); + // But still handle flags, e.g. a `Flag::StopSending`. + self.state.handle_flag(flag) + } + Poll::Ready(Some((None, message))) => drop(message), + Poll::Ready(None) | Poll::Pending => break, + } + } + _ => break, + } + } + match self.state { State::WriteClosed { .. } | State::ReadWriteClosed { .. } | State::ReadResetWriteClosed => return Poll::Ready(Ok(0)), - State::Open { .. } | State::ReadReset => {} - State::ReadClosed { .. } => {} + State::Open { .. } | State::ReadClosed { .. } | State::ReadReset => {} State::Poisoned => todo!(), } From 9cd4ef7584c2e83e768b6dc728d911a503f1f827 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Sun, 11 Sep 2022 18:51:23 +0900 Subject: [PATCH 08/39] transports/webrtc: Enforce maximum message length --- .../src/connection/poll_data_channel.rs | 48 ++++++++++++++++++- transports/webrtc/src/lib.rs | 35 -------------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/transports/webrtc/src/connection/poll_data_channel.rs b/transports/webrtc/src/connection/poll_data_channel.rs index 75153d50091..e805875dd98 100644 --- a/transports/webrtc/src/connection/poll_data_channel.rs +++ b/transports/webrtc/src/connection/poll_data_channel.rs @@ -35,6 +35,12 @@ use std::task::{Context, Poll}; use crate::message_proto::message::Flag; use crate::message_proto::Message; +// TODO: Document +const MAX_MSG_LEN: usize = 16384; // 16kiB +const VARINT_LEN: usize = 2; +const PROTO_OVERHEAD: usize = 5; +const MAX_DATA_LEN: usize = MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD; + /// A wrapper around [`RTCPollDataChannel`] implementing futures [`AsyncRead`] / [`AsyncWrite`]. // TODO // #[derive(Debug)] @@ -305,12 +311,14 @@ impl AsyncWrite for PollDataChannel { ready!(self.io.poll_ready_unpin(cx))?; + let n = usize::min(buf.len(), MAX_DATA_LEN); + Pin::new(&mut self.io).start_send(Message { flag: None, - message: Some(buf.into()), + message: Some(buf[0..n].into()), })?; - Poll::Ready(Ok(buf.len())) + Poll::Ready(Ok(n)) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -353,3 +361,39 @@ impl AsyncWrite for PollDataChannel { self.io.poll_flush_unpin(cx).map_err(Into::into) } } + +#[cfg(test)] +mod tests { + use super::*; + use asynchronous_codec::Encoder; + use bytes::BytesMut; + use prost::Message; + use unsigned_varint::codec::UviBytes; + + #[test] + fn max_data_len() { + // Largest possible message. + let message = [0; MAX_DATA_LEN]; + + let protobuf = crate::message_proto::Message { + flag: Some(crate::message_proto::message::Flag::Fin.into()), + message: Some(message.to_vec()), + }; + + let mut encoded_msg = BytesMut::new(); + protobuf + .encode(&mut encoded_msg) + .expect("BytesMut to have sufficient capacity."); + assert_eq!(encoded_msg.len(), message.len() + PROTO_OVERHEAD); + + let mut uvi = UviBytes::default(); + let mut dst = BytesMut::new(); + uvi.encode(encoded_msg.clone().freeze(), &mut dst).unwrap(); + + // Ensure the varint prefixed and protobuf encoded largest message is no longer than the + // maximum limit specified in the libp2p WebRTC specification. + assert_eq!(dst.len(), MAX_MSG_LEN); + + assert_eq!(dst.len() - encoded_msg.len(), VARINT_LEN); + } +} diff --git a/transports/webrtc/src/lib.rs b/transports/webrtc/src/lib.rs index cffabe3c6fd..4d5615353b4 100644 --- a/transports/webrtc/src/lib.rs +++ b/transports/webrtc/src/lib.rs @@ -94,38 +94,3 @@ mod webrtc_connection; mod message_proto { include!(concat!(env!("OUT_DIR"), "/webrtc.pb.rs")); } - -#[cfg(test)] -mod tests { - use super::*; - use asynchronous_codec::Encoder; - use bytes::BytesMut; - use prost::Message; - use unsigned_varint::codec::UviBytes; - - const MAX_MSG_LEN: usize = 16384; // 16kiB - const VARINT_LEN: usize = 2; - const PROTO_OVERHEAD: usize = 5; - - #[test] - fn proto_size() { - let message = [0; MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD]; - - let protobuf = message_proto::Message { - flag: Some(message_proto::message::Flag::CloseWrite.into()), - message: Some(message.to_vec()), - }; - - let mut encoded_msg = BytesMut::new(); - protobuf - .encode(&mut encoded_msg) - .expect("BytesMut to have sufficient capacity."); - assert_eq!(encoded_msg.len(), message.len() + PROTO_OVERHEAD); - - let mut uvi = UviBytes::default(); - let mut dst = BytesMut::new(); - uvi.encode(encoded_msg.clone().freeze(), &mut dst).unwrap(); - assert_eq!(dst.len(), MAX_MSG_LEN); - assert_eq!(dst.len() - encoded_msg.len(), VARINT_LEN); - } -} From 31e019acadc44e9a68b7d12e04cdcc629f93f2ef Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Mon, 19 Sep 2022 16:17:27 +0400 Subject: [PATCH 09/39] minor refactoring --- .../src/connection/poll_data_channel.rs | 64 ++++++++----------- 1 file changed, 27 insertions(+), 37 deletions(-) diff --git a/transports/webrtc/src/connection/poll_data_channel.rs b/transports/webrtc/src/connection/poll_data_channel.rs index e805875dd98..23e7d9a9ee7 100644 --- a/transports/webrtc/src/connection/poll_data_channel.rs +++ b/transports/webrtc/src/connection/poll_data_channel.rs @@ -192,26 +192,24 @@ impl PollDataChannel { pub fn set_read_buf_capacity(&mut self, capacity: usize) { self.io.get_mut().set_read_buf_capacity(capacity) } +} - fn io_poll_next( - io: &mut Framed, prost_codec::Codec>, - cx: &mut Context<'_>, - ) -> Poll, Option>)>>> { - match ready!(io.poll_next_unpin(cx)) - .transpose() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? - { - Some(Message { flag, message }) => { - let flag = flag - .map(|f| { - Flag::from_i32(f).ok_or(io::Error::new(io::ErrorKind::InvalidData, "")) - }) - .transpose()?; - - Poll::Ready(Ok(Some((flag, message)))) - } - None => Poll::Ready(Ok(None)), +fn io_poll_next( + io: &mut Framed, prost_codec::Codec>, + cx: &mut Context<'_>, +) -> Poll, Option>)>>> { + match ready!(io.poll_next_unpin(cx)) + .transpose() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + { + Some(Message { flag, message }) => { + let flag = flag + .map(|f| Flag::from_i32(f).ok_or(io::Error::new(io::ErrorKind::InvalidData, ""))) + .transpose()?; + + Poll::Ready(Ok(Some((flag, message)))) } + None => Poll::Ready(Ok(None)), } } @@ -222,38 +220,30 @@ impl AsyncRead for PollDataChannel { buf: &mut [u8], ) -> Poll> { loop { - if let Some(read_buffer) = self.state.read_buffer_mut() { - if !read_buffer.is_empty() { - let n = std::cmp::min(read_buffer.len(), buf.len()); - let data = read_buffer.split_to(n); - buf[0..n].copy_from_slice(&data[..]); + if let Some(read_buffer) = self.state.read_buffer_mut() && !read_buffer.is_empty() { + let n = std::cmp::min(read_buffer.len(), buf.len()); + let data = read_buffer.split_to(n); + buf[0..n].copy_from_slice(&data[..]); - return Poll::Ready(Ok(n)); - } + return Poll::Ready(Ok(n)); } - let PollDataChannel { state, io } = &mut *self; + let Self { state, io } = &mut *self; let read_buffer = match state { - State::Open { - ref mut read_buffer, - } - | State::WriteClosed { - ref mut read_buffer, - } => read_buffer, + State::Open { read_buffer } | State::WriteClosed { read_buffer } => read_buffer, State::ReadClosed { read_buffer, .. } | State::ReadWriteClosed { read_buffer, .. } => { assert!(read_buffer.is_empty()); return Poll::Ready(Ok(0)); } State::ReadReset | State::ReadResetWriteClosed => { - // TODO: Is `""` valid? - return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, ""))); + return Poll::Ready(Err(io::Error::from(io::ErrorKind::ConnectionReset))); } - State::Poisoned => todo!(), + State::Poisoned => unreachable!(), }; - match ready!(Self::io_poll_next(io, cx))? { + match ready!(io_poll_next(io, cx))? { Some((flag, message)) => { assert!(read_buffer.is_empty()); if let Some(message) = message { @@ -286,7 +276,7 @@ impl AsyncWrite for PollDataChannel { // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we will poll the // underlying I/O resource once more. Is that allowed? How about introducing a state IoReadClosed? { - match Self::io_poll_next(&mut self.io, cx)? { + match io_poll_next(&mut self.io, cx)? { Poll::Ready(Some((Some(flag), message))) => { // Read side is closed. Discard any incoming messages. drop(message); From 6f57ed6363c8b1db63738895a98663573a96a7cc Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Mon, 19 Sep 2022 16:24:10 +0400 Subject: [PATCH 10/39] rename PollDataChannel to Substream --- transports/webrtc/src/connection.rs | 12 +- .../{poll_data_channel.rs => substream.rs} | 164 +++++++++--------- transports/webrtc/src/transport.rs | 6 +- 3 files changed, 91 insertions(+), 91 deletions(-) rename transports/webrtc/src/connection/{poll_data_channel.rs => substream.rs} (98%) diff --git a/transports/webrtc/src/connection.rs b/transports/webrtc/src/connection.rs index c664ec033d6..a807d0c09da 100644 --- a/transports/webrtc/src/connection.rs +++ b/transports/webrtc/src/connection.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -mod poll_data_channel; +mod substream; use futures::{ channel::{ @@ -42,7 +42,7 @@ use std::{ }; use crate::error::Error; -pub(crate) use poll_data_channel::PollDataChannel; +pub(crate) use substream::Substream; const MAX_DATA_CHANNELS_IN_FLIGHT: usize = 10; @@ -57,7 +57,7 @@ pub struct Connection { incoming_data_channels_rx: mpsc::Receiver>, /// Temporary read buffer's capacity (equal for all data channels). - /// See [`PollDataChannel`] `read_buf_cap`. + /// See [`Substream`] `read_buf_cap`. read_buf_cap: Option, /// Future, which, once polled, will result in an outbound substream. @@ -149,7 +149,7 @@ impl Connection { } impl<'a> StreamMuxer for Connection { - type Substream = PollDataChannel; + type Substream = Substream; type Error = Error; fn poll_inbound( @@ -160,7 +160,7 @@ impl<'a> StreamMuxer for Connection { Some(detached) => { trace!("Incoming substream {}", detached.stream_identifier()); - let mut ch = PollDataChannel::new(detached); + let mut ch = Substream::new(detached); if let Some(cap) = self.read_buf_cap { ch.set_read_buf_capacity(cap); } @@ -213,7 +213,7 @@ impl<'a> StreamMuxer for Connection { match ready!(fut.as_mut().poll(cx)) { Ok(detached) => { - let mut ch = PollDataChannel::new(detached); + let mut ch = Substream::new(detached); if let Some(cap) = self.read_buf_cap { ch.set_read_buf_capacity(cap); } diff --git a/transports/webrtc/src/connection/poll_data_channel.rs b/transports/webrtc/src/connection/substream.rs similarity index 98% rename from transports/webrtc/src/connection/poll_data_channel.rs rename to transports/webrtc/src/connection/substream.rs index 23e7d9a9ee7..b065149e3bd 100644 --- a/transports/webrtc/src/connection/poll_data_channel.rs +++ b/transports/webrtc/src/connection/substream.rs @@ -44,90 +44,13 @@ const MAX_DATA_LEN: usize = MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD; /// A wrapper around [`RTCPollDataChannel`] implementing futures [`AsyncRead`] / [`AsyncWrite`]. // TODO // #[derive(Debug)] -pub struct PollDataChannel { +pub struct Substream { io: Framed, prost_codec::Codec>, state: State, } -enum State { - Open { read_buffer: Bytes }, - WriteClosed { read_buffer: Bytes }, - ReadClosed { read_buffer: Bytes }, - ReadWriteClosed { read_buffer: Bytes }, - ReadReset, - ReadResetWriteClosed, - Poisoned, -} - -impl State { - fn handle_flag(&mut self, flag: Flag) { - match (std::mem::replace(self, State::Poisoned), flag) { - // StopSending - ( - State::Open { read_buffer } | State::WriteClosed { read_buffer }, - Flag::StopSending, - ) => { - *self = State::WriteClosed { read_buffer }; - } - - ( - State::ReadClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, - Flag::StopSending, - ) => { - *self = State::ReadWriteClosed { read_buffer }; - } - - (State::ReadReset | State::ReadResetWriteClosed, Flag::StopSending) => { - *self = State::ReadResetWriteClosed; - } - - // Fin - (State::Open { read_buffer } | State::ReadClosed { read_buffer }, Flag::Fin) => { - *self = State::ReadClosed { read_buffer }; - } - - ( - State::WriteClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, - Flag::Fin, - ) => { - *self = State::ReadWriteClosed { read_buffer }; - } - - (State::ReadReset, Flag::Fin) => *self = State::ReadReset, - - (State::ReadResetWriteClosed, Flag::Fin) => *self = State::ReadResetWriteClosed, - - // Reset - (State::ReadClosed { .. } | State::ReadReset | State::Open { .. }, Flag::Reset) => { - *self = State::ReadReset - } - - ( - State::ReadWriteClosed { .. } - | State::WriteClosed { .. } - | State::ReadResetWriteClosed, - Flag::Reset, - ) => *self = State::ReadResetWriteClosed, - - (State::Poisoned, _) => unreachable!(), - } - } - - fn read_buffer_mut(&mut self) -> Option<&mut Bytes> { - match self { - State::Open { read_buffer } => Some(read_buffer), - State::WriteClosed { read_buffer } => Some(read_buffer), - State::ReadClosed { read_buffer } => Some(read_buffer), - State::ReadWriteClosed { read_buffer } => Some(read_buffer), - State::ReadReset => None, - State::ReadResetWriteClosed => None, - State::Poisoned => todo!(), - } - } -} - -impl PollDataChannel { - /// Constructs a new `PollDataChannel`. +impl Substream { + /// Constructs a new `Substream`. pub fn new(data_channel: Arc) -> Self { Self { io: Framed::new( @@ -213,7 +136,7 @@ fn io_poll_next( } } -impl AsyncRead for PollDataChannel { +impl AsyncRead for Substream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -263,7 +186,7 @@ impl AsyncRead for PollDataChannel { } } -impl AsyncWrite for PollDataChannel { +impl AsyncWrite for Substream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -352,6 +275,83 @@ impl AsyncWrite for PollDataChannel { } } +enum State { + Open { read_buffer: Bytes }, + WriteClosed { read_buffer: Bytes }, + ReadClosed { read_buffer: Bytes }, + ReadWriteClosed { read_buffer: Bytes }, + ReadReset, + ReadResetWriteClosed, + Poisoned, +} + +impl State { + fn handle_flag(&mut self, flag: Flag) { + match (std::mem::replace(self, State::Poisoned), flag) { + // StopSending + ( + State::Open { read_buffer } | State::WriteClosed { read_buffer }, + Flag::StopSending, + ) => { + *self = State::WriteClosed { read_buffer }; + } + + ( + State::ReadClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, + Flag::StopSending, + ) => { + *self = State::ReadWriteClosed { read_buffer }; + } + + (State::ReadReset | State::ReadResetWriteClosed, Flag::StopSending) => { + *self = State::ReadResetWriteClosed; + } + + // Fin + (State::Open { read_buffer } | State::ReadClosed { read_buffer }, Flag::Fin) => { + *self = State::ReadClosed { read_buffer }; + } + + ( + State::WriteClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, + Flag::Fin, + ) => { + *self = State::ReadWriteClosed { read_buffer }; + } + + (State::ReadReset, Flag::Fin) => *self = State::ReadReset, + + (State::ReadResetWriteClosed, Flag::Fin) => *self = State::ReadResetWriteClosed, + + // Reset + (State::ReadClosed { .. } | State::ReadReset | State::Open { .. }, Flag::Reset) => { + *self = State::ReadReset + } + + ( + State::ReadWriteClosed { .. } + | State::WriteClosed { .. } + | State::ReadResetWriteClosed, + Flag::Reset, + ) => *self = State::ReadResetWriteClosed, + + (State::Poisoned, _) => unreachable!(), + } + } + + fn read_buffer_mut(&mut self) -> Option<&mut Bytes> { + match self { + State::Open { read_buffer } => Some(read_buffer), + State::WriteClosed { read_buffer } => Some(read_buffer), + State::ReadClosed { read_buffer } => Some(read_buffer), + State::ReadWriteClosed { read_buffer } => Some(read_buffer), + State::ReadReset => None, + State::ReadResetWriteClosed => None, + State::Poisoned => todo!(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/transports/webrtc/src/transport.rs b/transports/webrtc/src/transport.rs index 00358b3bea4..f79726d796c 100644 --- a/transports/webrtc/src/transport.rs +++ b/transports/webrtc/src/transport.rs @@ -51,7 +51,7 @@ use std::{ use crate::{ connection::Connection, - connection::PollDataChannel, + connection::Substream, error::Error, fingerprint::Fingerprint, in_addr::InAddr, @@ -191,7 +191,7 @@ impl Transport for WebRTCTransport { trace!("noise handshake with addr={}", remote); let peer_id = perform_noise_handshake_outbound( id_keys, - PollDataChannel::new(data_channel.clone()), + Substream::new(data_channel.clone()), our_fingerprint, remote_fingerprint, ) @@ -579,7 +579,7 @@ async fn upgrade( let remote_fingerprint = conn.get_remote_fingerprint().await; let peer_id = perform_noise_handshake_inbound( id_keys, - PollDataChannel::new(data_channel.clone()), + Substream::new(data_channel.clone()), our_fingerprint, remote_fingerprint, ) From 4a1d4d672095af8aa9764946a7e649bc6f1c4d87 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Mon, 19 Sep 2022 17:32:25 +0400 Subject: [PATCH 11/39] add comments --- transports/webrtc/src/connection/substream.rs | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/transports/webrtc/src/connection/substream.rs b/transports/webrtc/src/connection/substream.rs index b065149e3bd..09f1533788c 100644 --- a/transports/webrtc/src/connection/substream.rs +++ b/transports/webrtc/src/connection/substream.rs @@ -27,23 +27,29 @@ use tokio_util::compat::TokioAsyncReadCompatExt; use webrtc::data::data_channel::DataChannel; use webrtc::data::data_channel::PollDataChannel as RTCPollDataChannel; -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; +use std::{ + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; use crate::message_proto::message::Flag; use crate::message_proto::Message; -// TODO: Document +/// Maximum length of a message, in bytes. const MAX_MSG_LEN: usize = 16384; // 16kiB +/// Length of varint, in bytes. const VARINT_LEN: usize = 2; +/// Overhead of the protobuf encoding, in bytes. const PROTO_OVERHEAD: usize = 5; +/// Maximum length of data, in bytes. const MAX_DATA_LEN: usize = MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD; -/// A wrapper around [`RTCPollDataChannel`] implementing futures [`AsyncRead`] / [`AsyncWrite`]. -// TODO -// #[derive(Debug)] +/// Substream is a wrapper around [`RTCPollDataChannel`] implementing futures [`AsyncRead`] / +/// [`AsyncWrite`] and message framing (as per specification). +/// +/// #[derive(Debug)] pub struct Substream { io: Framed, prost_codec::Codec>, state: State, @@ -55,8 +61,7 @@ impl Substream { Self { io: Framed::new( RTCPollDataChannel::new(data_channel).compat(), - // TODO: Fix MAX - prost_codec::Codec::new(usize::MAX), + prost_codec::Codec::new(MAX_MSG_LEN), ), state: State::Open { read_buffer: Default::default(), From c6c5a963fe6dedcf515c85f385bdf2bc2f418d45 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Tue, 20 Sep 2022 09:32:08 +0400 Subject: [PATCH 12/39] add debug to handle_flag --- transports/webrtc/src/connection/substream.rs | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/transports/webrtc/src/connection/substream.rs b/transports/webrtc/src/connection/substream.rs index 09f1533788c..ca35dc48dab 100644 --- a/transports/webrtc/src/connection/substream.rs +++ b/transports/webrtc/src/connection/substream.rs @@ -28,7 +28,7 @@ use webrtc::data::data_channel::DataChannel; use webrtc::data::data_channel::PollDataChannel as RTCPollDataChannel; use std::{ - io, + fmt, io, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -156,6 +156,7 @@ impl AsyncRead for Substream { return Poll::Ready(Ok(n)); } + let substream_id = self.stream_identifier(); let Self { state, io } = &mut *self; let read_buffer = match state { @@ -179,11 +180,11 @@ impl AsyncRead for Substream { } if let Some(flag) = flag { - self.state.handle_flag(flag) + self.state.handle_flag(flag, substream_id) }; } None => { - self.state.handle_flag(Flag::Fin); + self.state.handle_flag(Flag::Fin, substream_id); return Poll::Ready(Ok(0)); } } @@ -197,6 +198,7 @@ impl AsyncWrite for Substream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + let substream_id = self.stream_identifier(); // Handle flags iff read side closed. loop { match self.state { @@ -209,7 +211,7 @@ impl AsyncWrite for Substream { // Read side is closed. Discard any incoming messages. drop(message); // But still handle flags, e.g. a `Flag::StopSending`. - self.state.handle_flag(flag) + self.state.handle_flag(flag, substream_id) } Poll::Ready(Some((None, message))) => drop(message), Poll::Ready(None) | Poll::Pending => break, @@ -291,7 +293,8 @@ enum State { } impl State { - fn handle_flag(&mut self, flag: Flag) { + fn handle_flag(&mut self, flag: Flag, substream_id: u16) { + let old_state = format!("{}", self); match (std::mem::replace(self, State::Poisoned), flag) { // StopSending ( @@ -342,6 +345,14 @@ impl State { (State::Poisoned, _) => unreachable!(), } + + log::debug!( + "substream={}: got flag {:?}, moved from {} to {}", + substream_id, + flag, + old_state, + *self + ); } fn read_buffer_mut(&mut self) -> Option<&mut Bytes> { @@ -357,6 +368,20 @@ impl State { } } +impl fmt::Display for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + State::Open { .. } => write!(f, "Open"), + State::WriteClosed { .. } => write!(f, "WriteClosed"), + State::ReadClosed { .. } => write!(f, "ReadClosed"), + State::ReadWriteClosed { .. } => write!(f, "ReadWriteClosed"), + State::ReadReset => write!(f, "ReadReset"), + State::ReadResetWriteClosed => write!(f, "ReadResetWriteClosed"), + State::Poisoned => write!(f, "Poisoned"), + } + } +} + #[cfg(test)] mod tests { use super::*; From 1b0b671ab086950d1e7b00de9ee8200da0b7be57 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 26 Sep 2022 18:51:16 +1000 Subject: [PATCH 13/39] Create noise-prologue from server + client FP in fixed order --- transports/webrtc/src/transport.rs | 62 ++++++++++++++---------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/transports/webrtc/src/transport.rs b/transports/webrtc/src/transport.rs index 775a703090b..bf5f25c57f3 100644 --- a/transports/webrtc/src/transport.rs +++ b/transports/webrtc/src/transport.rs @@ -35,7 +35,6 @@ use libp2p_core::{ }; use libp2p_noise::{Keypair, NoiseConfig, X25519Spec}; use log::{debug, trace}; -use multihash::Multihash; use tokio_crate::net::UdpSocket; use webrtc::ice::udp_mux::UDPMux; use webrtc::peer_connection::certificate::RTCCertificate; @@ -162,7 +161,7 @@ impl Transport for WebRTCTransport { trace!("dialing addr={}", remote); let config = self.config.clone(); - let our_fingerprint = self.config.fingerprint_of_first_certificate(); + let client_fingerprint = self.config.fingerprint_of_first_certificate(); let id_keys = self.id_keys.clone(); let first_listener = self @@ -175,14 +174,14 @@ impl Transport for WebRTCTransport { // [`Transport::dial`] should do no work unless the returned [`Future`] is polled. Thus // do the `set_remote_description` call within the [`Future`]. Ok(async move { - let remote_fingerprint = fingerprint_from_addr(&addr) + let server_fingerprint = fingerprint_from_addr(&addr) .ok_or_else(|| Error::InvalidMultiaddr(addr.clone()))?; let conn = WebRTCConnection::connect( sock_addr, config.into_inner(), udp_mux, - &remote_fingerprint, + &server_fingerprint, ) .await?; @@ -193,8 +192,8 @@ impl Transport for WebRTCTransport { let peer_id = perform_noise_handshake_outbound( id_keys, PollDataChannel::new(data_channel.clone()), - our_fingerprint, - remote_fingerprint, + client_fingerprint, + server_fingerprint, ) .await?; @@ -500,8 +499,8 @@ fn multiaddr_to_socketaddr(addr: &Multiaddr) -> Option { async fn perform_noise_handshake_outbound( id_keys: identity::Keypair, poll_data_channel: T, - our_fingerprint: Fingerprint, - remote_fingerprint: Fingerprint, + client_fingerprint: Fingerprint, + server_fingerprint: Fingerprint, ) -> Result where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -509,8 +508,8 @@ where let dh_keys = Keypair::::new() .into_authentic(&id_keys) .unwrap(); - let noise = - NoiseConfig::xx(dh_keys).with_prologue(noise_prologue(our_fingerprint, remote_fingerprint)); + let noise = NoiseConfig::xx(dh_keys) + .with_prologue(noise_prologue(client_fingerprint, server_fingerprint)); let info = noise.protocol_info().next().unwrap(); let (peer_id, _noise_io) = noise .into_authenticated() @@ -529,13 +528,13 @@ async fn upgrade( ) -> Result<(PeerId, Connection), Error> { trace!("upgrading addr={} (ufrag={})", socket_addr, ufrag); - let our_fingerprint = config.fingerprint_of_first_certificate(); + let server_fingerprint = config.fingerprint_of_first_certificate(); let conn = WebRTCConnection::accept( socket_addr, config.into_inner(), udp_mux, - &our_fingerprint, + &server_fingerprint, &ufrag, ) .await?; @@ -548,12 +547,13 @@ async fn upgrade( socket_addr, ufrag ); - let remote_fingerprint = conn.get_remote_fingerprint().await; + let client_fingerprint = conn.get_remote_fingerprint().await; + let peer_id = perform_noise_handshake_inbound( id_keys, PollDataChannel::new(data_channel.clone()), - our_fingerprint, - remote_fingerprint, + client_fingerprint, + server_fingerprint, ) .await?; @@ -574,8 +574,8 @@ async fn upgrade( async fn perform_noise_handshake_inbound( id_keys: identity::Keypair, poll_data_channel: T, - our_fingerprint: Fingerprint, - remote_fingerprint: Fingerprint, + client_fingerprint: Fingerprint, + server_fingerprint: Fingerprint, ) -> Result where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -583,8 +583,8 @@ where let dh_keys = Keypair::::new() .into_authentic(&id_keys) .unwrap(); - let noise = - NoiseConfig::xx(dh_keys).with_prologue(noise_prologue(our_fingerprint, remote_fingerprint)); + let noise = NoiseConfig::xx(dh_keys) + .with_prologue(noise_prologue(client_fingerprint, server_fingerprint)); let info = noise.protocol_info().next().unwrap(); let (peer_id, _noise_io) = noise .into_authenticated() @@ -593,18 +593,17 @@ where Ok(peer_id) } -fn noise_prologue(our_fingerprint: Fingerprint, remote_fingerprint: Fingerprint) -> Vec { - let (a, b): (Multihash, Multihash) = ( - our_fingerprint.to_multi_hash(), - remote_fingerprint.to_multi_hash(), - ); - let (a, b) = (a.to_bytes(), b.to_bytes()); - let (first, second) = if a < b { (a, b) } else { (b, a) }; +fn noise_prologue(client: Fingerprint, server: Fingerprint) -> Vec { + let server = server.to_multi_hash().to_bytes(); + let client = client.to_multi_hash().to_bytes(); + const PREFIX: &[u8] = b"libp2p-webrtc-noise:"; - let mut out = Vec::with_capacity(PREFIX.len() + first.len() + second.len()); + + let mut out = Vec::with_capacity(PREFIX.len() + server.len() + client.len()); out.extend_from_slice(PREFIX); - out.extend_from_slice(&first); - out.extend_from_slice(&second); + out.extend_from_slice(&server); + out.extend_from_slice(&client); + out } @@ -633,10 +632,7 @@ mod tests { let prologue2 = noise_prologue(b, a); assert_eq!(hex::encode(&prologue1), "6c69627032702d7765627274632d6e6f6973653a122030fc9f469c207419dfdd0aab5f27a86c973c94e40548db9375cca2e915973b9912203e79af40d6059617a0d83b83a52ce73b0c1f37a72c6043ad2969e2351bdca870"); - assert_eq!( - prologue1, prologue2, - "order of fingerprints does not matter" - ); + assert_eq!(hex::encode(&prologue2), "6c69627032702d7765627274632d6e6f6973653a12203e79af40d6059617a0d83b83a52ce73b0c1f37a72c6043ad2969e2351bdca870122030fc9f469c207419dfdd0aab5f27a86c973c94e40548db9375cca2e915973b99"); } #[test] From aa38c81cffecd1069fc30a4138466d9553ad5fdf Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 12:48:01 +1000 Subject: [PATCH 14/39] Make `substream` a top-level module --- transports/webrtc/src/connection.rs | 4 +--- transports/webrtc/src/lib.rs | 1 + transports/webrtc/src/{connection => }/substream.rs | 0 transports/webrtc/src/transport.rs | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) rename transports/webrtc/src/{connection => }/substream.rs (100%) diff --git a/transports/webrtc/src/connection.rs b/transports/webrtc/src/connection.rs index a807d0c09da..b3e351c79d2 100644 --- a/transports/webrtc/src/connection.rs +++ b/transports/webrtc/src/connection.rs @@ -18,8 +18,6 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -mod substream; - use futures::{ channel::{ mpsc, @@ -42,7 +40,7 @@ use std::{ }; use crate::error::Error; -pub(crate) use substream::Substream; +use crate::substream::Substream; const MAX_DATA_CHANNELS_IN_FLIGHT: usize = 10; diff --git a/transports/webrtc/src/lib.rs b/transports/webrtc/src/lib.rs index 4d5615353b4..463e3ba7a6a 100644 --- a/transports/webrtc/src/lib.rs +++ b/transports/webrtc/src/lib.rs @@ -88,6 +88,7 @@ mod fingerprint; mod in_addr; mod req_res_chan; mod sdp; +mod substream; mod udp_mux; mod webrtc_connection; diff --git a/transports/webrtc/src/connection/substream.rs b/transports/webrtc/src/substream.rs similarity index 100% rename from transports/webrtc/src/connection/substream.rs rename to transports/webrtc/src/substream.rs diff --git a/transports/webrtc/src/transport.rs b/transports/webrtc/src/transport.rs index f79726d796c..e7f564002ca 100644 --- a/transports/webrtc/src/transport.rs +++ b/transports/webrtc/src/transport.rs @@ -51,10 +51,10 @@ use std::{ use crate::{ connection::Connection, - connection::Substream, error::Error, fingerprint::Fingerprint, in_addr::InAddr, + substream::Substream, udp_mux::{UDPMuxEvent, UDPMuxNewAddr}, webrtc_connection::WebRTCConnection, }; From d0e918bc60617da939d59edaa765a974dd3fc891 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 12:52:58 +1000 Subject: [PATCH 15/39] Replace nightly feature with refactoring --- transports/webrtc/src/substream.rs | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index ca35dc48dab..0648a169fc8 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -148,7 +148,7 @@ impl AsyncRead for Substream { buf: &mut [u8], ) -> Poll> { loop { - if let Some(read_buffer) = self.state.read_buffer_mut() && !read_buffer.is_empty() { + if let Some(read_buffer) = self.state.non_empty_read_buffer_mut() { let n = std::cmp::min(read_buffer.len(), buf.len()); let data = read_buffer.split_to(n); buf[0..n].copy_from_slice(&data[..]); @@ -355,15 +355,19 @@ impl State { ); } - fn read_buffer_mut(&mut self) -> Option<&mut Bytes> { + /// Returns a reference to the underlying buffer if possible and the buffer is not empty. + fn non_empty_read_buffer_mut(&mut self) -> Option<&mut Bytes> { match self { - State::Open { read_buffer } => Some(read_buffer), - State::WriteClosed { read_buffer } => Some(read_buffer), - State::ReadClosed { read_buffer } => Some(read_buffer), - State::ReadWriteClosed { read_buffer } => Some(read_buffer), - State::ReadReset => None, - State::ReadResetWriteClosed => None, - State::Poisoned => todo!(), + State::Open { read_buffer } + | State::WriteClosed { read_buffer } + | State::ReadClosed { read_buffer } + | State::ReadWriteClosed { read_buffer } + if !read_buffer.is_empty() => + { + Some(read_buffer) + } + State::Poisoned => unreachable!(), + _ => None, } } } From 171c613309a4f8379516e251575163a6642db53a Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 12:54:21 +1000 Subject: [PATCH 16/39] Remove use of import rename --- transports/webrtc/src/substream.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 0648a169fc8..5017c2a4701 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -25,7 +25,7 @@ use futures::ready; use tokio_util::compat::Compat; use tokio_util::compat::TokioAsyncReadCompatExt; use webrtc::data::data_channel::DataChannel; -use webrtc::data::data_channel::PollDataChannel as RTCPollDataChannel; +use webrtc::data::data_channel::PollDataChannel; use std::{ fmt, io, @@ -51,7 +51,7 @@ const MAX_DATA_LEN: usize = MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD; /// /// #[derive(Debug)] pub struct Substream { - io: Framed, prost_codec::Codec>, + io: Framed, prost_codec::Codec>, state: State, } @@ -60,7 +60,7 @@ impl Substream { pub fn new(data_channel: Arc) -> Self { Self { io: Framed::new( - RTCPollDataChannel::new(data_channel).compat(), + PollDataChannel::new(data_channel).compat(), prost_codec::Codec::new(MAX_MSG_LEN), ), state: State::Open { @@ -70,12 +70,12 @@ impl Substream { } /// Get back the inner data_channel. - pub fn into_inner(self) -> RTCPollDataChannel { + pub fn into_inner(self) -> PollDataChannel { self.io.into_inner().into_inner() } /// Obtain a clone of the inner data_channel. - pub fn clone_inner(&self) -> RTCPollDataChannel { + pub fn clone_inner(&self) -> PollDataChannel { self.io.get_ref().clone() } @@ -123,7 +123,7 @@ impl Substream { } fn io_poll_next( - io: &mut Framed, prost_codec::Codec>, + io: &mut Framed, prost_codec::Codec>, cx: &mut Context<'_>, ) -> Poll, Option>)>>> { match ready!(io.poll_next_unpin(cx)) From 64026650c984cef2c8046fef3f64ed2405c23be4 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 13:09:36 +1000 Subject: [PATCH 17/39] Remove unused public API --- transports/webrtc/src/substream.rs | 42 ------------------------------ 1 file changed, 42 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 5017c2a4701..20368155b98 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -69,53 +69,11 @@ impl Substream { } } - /// Get back the inner data_channel. - pub fn into_inner(self) -> PollDataChannel { - self.io.into_inner().into_inner() - } - - /// Obtain a clone of the inner data_channel. - pub fn clone_inner(&self) -> PollDataChannel { - self.io.get_ref().clone() - } - - /// MessagesSent returns the number of messages sent - pub fn messages_sent(&self) -> usize { - self.io.get_ref().messages_sent() - } - - /// MessagesReceived returns the number of messages received - pub fn messages_received(&self) -> usize { - self.io.get_ref().messages_received() - } - - /// BytesSent returns the number of bytes sent - pub fn bytes_sent(&self) -> usize { - self.io.get_ref().bytes_sent() - } - - /// BytesReceived returns the number of bytes received - pub fn bytes_received(&self) -> usize { - self.io.get_ref().bytes_received() - } - /// StreamIdentifier returns the Stream identifier associated to the stream. pub fn stream_identifier(&self) -> u16 { self.io.get_ref().stream_identifier() } - /// BufferedAmount returns the number of bytes of data currently queued to be - /// sent over this stream. - pub fn buffered_amount(&self) -> usize { - self.io.get_ref().buffered_amount() - } - - /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing - /// data that is considered "low." Defaults to 0. - pub fn buffered_amount_low_threshold(&self) -> usize { - self.io.get_ref().buffered_amount_low_threshold() - } - /// Set the capacity of the temporary read buffer (default: 8192). pub fn set_read_buf_capacity(&mut self, capacity: usize) { self.io.get_mut().set_read_buf_capacity(capacity) From d53769671239a135c4a98b883be9bb6c8301220f Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 13:13:10 +1000 Subject: [PATCH 18/39] Make sure we don't construct substreams outside of this crate --- transports/webrtc/src/substream.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 20368155b98..b3c5cf70282 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -57,7 +57,7 @@ pub struct Substream { impl Substream { /// Constructs a new `Substream`. - pub fn new(data_channel: Arc) -> Self { + pub(crate) fn new(data_channel: Arc) -> Self { Self { io: Framed::new( PollDataChannel::new(data_channel).compat(), From c8c244659a4ef8b8bcfc49fe91107947ea66e239 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 13:14:34 +1000 Subject: [PATCH 19/39] Don't expose public APIs for temporary workarounds --- transports/webrtc/src/connection.rs | 23 ++--------------------- transports/webrtc/src/substream.rs | 16 +++++++--------- transports/webrtc/src/transport.rs | 13 ++----------- 3 files changed, 11 insertions(+), 41 deletions(-) diff --git a/transports/webrtc/src/connection.rs b/transports/webrtc/src/connection.rs index 93f7f9d3b15..c0468654600 100644 --- a/transports/webrtc/src/connection.rs +++ b/transports/webrtc/src/connection.rs @@ -54,10 +54,6 @@ pub struct Connection { /// Channel onto which incoming data channels are put. incoming_data_channels_rx: mpsc::Receiver>, - /// Temporary read buffer's capacity (equal for all data channels). - /// See [`Substream`] `read_buf_cap`. - read_buf_cap: Option, - /// Future, which, once polled, will result in an outbound substream. outbound_fut: Option, Error>>>, @@ -77,17 +73,11 @@ impl Connection { Self { peer_conn: Arc::new(FutMutex::new(rtc_conn)), incoming_data_channels_rx: data_channel_rx, - read_buf_cap: None, outbound_fut: None, close_fut: None, } } - /// Set the capacity of a data channel's temporary read buffer (equal for all data channels; default: 8192). - pub fn set_data_channels_read_buf_capacity(&mut self, cap: usize) { - self.read_buf_cap = Some(cap); - } - /// Registers a handler for incoming data channels. async fn register_incoming_data_channels_handler( rtc_conn: &RTCPeerConnection, @@ -157,12 +147,7 @@ impl StreamMuxer for Connection { Some(detached) => { trace!("Incoming substream {}", detached.stream_identifier()); - let mut ch = Substream::new(detached); - if let Some(cap) = self.read_buf_cap { - ch.set_read_buf_capacity(cap); - } - - Poll::Ready(Ok(ch)) + Poll::Ready(Ok(Substream::new(detached))) } None => Poll::Ready(Err(Error::InternalError( "incoming_data_channels_rx is closed (no messages left)".to_string(), @@ -210,12 +195,8 @@ impl StreamMuxer for Connection { match ready!(fut.as_mut().poll(cx)) { Ok(detached) => { - let mut ch = Substream::new(detached); - if let Some(cap) = self.read_buf_cap { - ch.set_read_buf_capacity(cap); - } self.outbound_fut = None; - Poll::Ready(Ok(ch)) + Poll::Ready(Ok(Substream::new(detached))) } Err(e) => { self.outbound_fut = None; diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index b3c5cf70282..d52e54dc317 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -58,11 +58,14 @@ pub struct Substream { impl Substream { /// Constructs a new `Substream`. pub(crate) fn new(data_channel: Arc) -> Self { + let mut inner = PollDataChannel::new(data_channel); + + // TODO: default buffer size is too small to fit some messages. Possibly remove once + // https://github.com/webrtc-rs/webrtc/issues/273 is fixed. + inner.set_read_buf_capacity(8192 * 10); + Self { - io: Framed::new( - PollDataChannel::new(data_channel).compat(), - prost_codec::Codec::new(MAX_MSG_LEN), - ), + io: Framed::new(inner.compat(), prost_codec::Codec::new(MAX_MSG_LEN)), state: State::Open { read_buffer: Default::default(), }, @@ -73,11 +76,6 @@ impl Substream { pub fn stream_identifier(&self) -> u16 { self.io.get_ref().stream_identifier() } - - /// Set the capacity of the temporary read buffer (default: 8192). - pub fn set_read_buf_capacity(&mut self, capacity: usize) { - self.io.get_mut().set_read_buf_capacity(capacity) - } } fn io_poll_next( diff --git a/transports/webrtc/src/transport.rs b/transports/webrtc/src/transport.rs index 575944694f5..507c200b308 100644 --- a/transports/webrtc/src/transport.rs +++ b/transports/webrtc/src/transport.rs @@ -212,11 +212,7 @@ impl Transport for WebRTCTransport { .await .map_err(|e| Error::WebRTC(webrtc::Error::Data(e)))?; - let mut c = Connection::new(conn.into_inner()).await; - // TODO: default buffer size is too small to fit some messages. Possibly remove once - // https://github.com/webrtc-rs/sctp/issues/28 is fixed. - c.set_data_channels_read_buf_capacity(8192 * 10); - Ok((peer_id, c)) + Ok((peer_id, Connection::new(conn.into_inner()).await)) } .boxed()) } @@ -563,12 +559,7 @@ async fn upgrade( .await .map_err(|e| Error::WebRTC(webrtc::Error::Data(e)))?; - let mut c = Connection::new(conn.into_inner()).await; - // TODO: default buffer size is too small to fit some messages. Possibly remove once - // https://github.com/webrtc-rs/sctp/issues/28 is fixed. - c.set_data_channels_read_buf_capacity(8192 * 10); - - Ok((peer_id, c)) + Ok((peer_id, Connection::new(conn.into_inner()).await)) } async fn perform_noise_handshake_inbound( From d58f219c1090c836c6f5e4b455a1474cc2851d23 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 13:17:23 +1000 Subject: [PATCH 20/39] Remove pub where not necessary --- transports/webrtc/src/substream.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index d52e54dc317..b37e5f216be 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -72,8 +72,7 @@ impl Substream { } } - /// StreamIdentifier returns the Stream identifier associated to the stream. - pub fn stream_identifier(&self) -> u16 { + fn stream_identifier(&self) -> u16 { self.io.get_ref().stream_identifier() } } From eb09d3651c88fa7f1cf7c15b1a86327ae6e98df1 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 13:18:27 +1000 Subject: [PATCH 21/39] Remove utilities below usage --- transports/webrtc/src/substream.rs | 38 +++++++++++++++--------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index b37e5f216be..b93fda5d02d 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -77,25 +77,6 @@ impl Substream { } } -fn io_poll_next( - io: &mut Framed, prost_codec::Codec>, - cx: &mut Context<'_>, -) -> Poll, Option>)>>> { - match ready!(io.poll_next_unpin(cx)) - .transpose() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? - { - Some(Message { flag, message }) => { - let flag = flag - .map(|f| Flag::from_i32(f).ok_or(io::Error::new(io::ErrorKind::InvalidData, ""))) - .transpose()?; - - Poll::Ready(Ok(Some((flag, message)))) - } - None => Poll::Ready(Ok(None)), - } -} - impl AsyncRead for Substream { fn poll_read( mut self: Pin<&mut Self>, @@ -237,6 +218,25 @@ impl AsyncWrite for Substream { } } +fn io_poll_next( + io: &mut Framed, prost_codec::Codec>, + cx: &mut Context<'_>, +) -> Poll, Option>)>>> { + match ready!(io.poll_next_unpin(cx)) + .transpose() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + { + Some(Message { flag, message }) => { + let flag = flag + .map(|f| Flag::from_i32(f).ok_or(io::Error::new(io::ErrorKind::InvalidData, ""))) + .transpose()?; + + Poll::Ready(Ok(Some((flag, message)))) + } + None => Poll::Ready(Ok(None)), + } +} + enum State { Open { read_buffer: Bytes }, WriteClosed { read_buffer: Bytes }, From 99af2a1e6ad5822befe99e049daa0d17de433a49 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 13:19:50 +1000 Subject: [PATCH 22/39] Remove stale derive --- transports/webrtc/src/substream.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index b93fda5d02d..8059048b1d3 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -48,8 +48,6 @@ const MAX_DATA_LEN: usize = MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD; /// Substream is a wrapper around [`RTCPollDataChannel`] implementing futures [`AsyncRead`] / /// [`AsyncWrite`] and message framing (as per specification). -/// -/// #[derive(Debug)] pub struct Substream { io: Framed, prost_codec::Codec>, state: State, From 2d85ab451ad301b23febeac9284c9e762068c7a4 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 13:21:43 +1000 Subject: [PATCH 23/39] Update docs --- transports/webrtc/src/substream.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 8059048b1d3..329fde7b9d9 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -46,8 +46,10 @@ const PROTO_OVERHEAD: usize = 5; /// Maximum length of data, in bytes. const MAX_DATA_LEN: usize = MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD; -/// Substream is a wrapper around [`RTCPollDataChannel`] implementing futures [`AsyncRead`] / -/// [`AsyncWrite`] and message framing (as per specification). +/// A substream on top of a WebRTC data channel. +/// +/// To be a proper libp2p substream, we need to implement [`AsyncRead`] and [`AsyncWrite`] as well +/// as support a half-closed state which we do by framing messages in a protobuf envelope. pub struct Substream { io: Framed, prost_codec::Codec>, state: State, From 6e2aeb17ff1e3e6c107d0145978c229568c6682e Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 27 Sep 2022 14:10:12 +1000 Subject: [PATCH 24/39] Fix clippy warnings --- transports/webrtc/src/lib.rs | 2 ++ transports/webrtc/src/substream.rs | 32 +++++++++++++----------------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/transports/webrtc/src/lib.rs b/transports/webrtc/src/lib.rs index a070015fdb7..a63d0fa624f 100644 --- a/transports/webrtc/src/lib.rs +++ b/transports/webrtc/src/lib.rs @@ -90,6 +90,8 @@ mod transport; mod udp_mux; mod webrtc_connection; mod message_proto { + #![allow(clippy::derive_partial_eq_without_eq)] + include!(concat!(env!("OUT_DIR"), "/webrtc.pb.rs")); } diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 329fde7b9d9..c78196e2fa9 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -136,24 +136,18 @@ impl AsyncWrite for Substream { ) -> Poll> { let substream_id = self.stream_identifier(); // Handle flags iff read side closed. - loop { - match self.state { - State::ReadClosed { .. } | State::ReadReset => - // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we will poll the - // underlying I/O resource once more. Is that allowed? How about introducing a state IoReadClosed? - { - match io_poll_next(&mut self.io, cx)? { - Poll::Ready(Some((Some(flag), message))) => { - // Read side is closed. Discard any incoming messages. - drop(message); - // But still handle flags, e.g. a `Flag::StopSending`. - self.state.handle_flag(flag, substream_id) - } - Poll::Ready(Some((None, message))) => drop(message), - Poll::Ready(None) | Poll::Pending => break, - } + while let State::ReadClosed { .. } | State::ReadReset = self.state { + // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we will poll the + // underlying I/O resource once more. Is that allowed? How about introducing a state IoReadClosed? + match io_poll_next(&mut self.io, cx)? { + Poll::Ready(Some((Some(flag), message))) => { + // Read side is closed. Discard any incoming messages. + drop(message); + // But still handle flags, e.g. a `Flag::StopSending`. + self.state.handle_flag(flag, substream_id) } - _ => break, + Poll::Ready(Some((None, message))) => drop(message), + Poll::Ready(None) | Poll::Pending => break, } } @@ -228,7 +222,9 @@ fn io_poll_next( { Some(Message { flag, message }) => { let flag = flag - .map(|f| Flag::from_i32(f).ok_or(io::Error::new(io::ErrorKind::InvalidData, ""))) + .map(|f| { + Flag::from_i32(f).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "")) + }) .transpose()?; Poll::Ready(Ok(Some((flag, message)))) From a6b2aacac87f284434560ad9015be16cc278d112 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 7 Oct 2022 10:58:04 +1100 Subject: [PATCH 25/39] Revert "Create noise-prologue from server + client FP in fixed order" This reverts commit 1b0b671ab086950d1e7b00de9ee8200da0b7be57. --- transports/webrtc/src/transport.rs | 62 ++++++++++++++++-------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/transports/webrtc/src/transport.rs b/transports/webrtc/src/transport.rs index 507c200b308..29078271b05 100644 --- a/transports/webrtc/src/transport.rs +++ b/transports/webrtc/src/transport.rs @@ -35,6 +35,7 @@ use libp2p_core::{ }; use libp2p_noise::{Keypair, NoiseConfig, X25519Spec}; use log::{debug, trace}; +use multihash::Multihash; use tokio_crate::net::UdpSocket; use webrtc::ice::udp_mux::UDPMux; use webrtc::peer_connection::certificate::RTCCertificate; @@ -161,7 +162,7 @@ impl Transport for WebRTCTransport { trace!("dialing addr={}", remote); let config = self.config.clone(); - let client_fingerprint = self.config.fingerprint_of_first_certificate(); + let our_fingerprint = self.config.fingerprint_of_first_certificate(); let id_keys = self.id_keys.clone(); let first_listener = self @@ -174,14 +175,14 @@ impl Transport for WebRTCTransport { // [`Transport::dial`] should do no work unless the returned [`Future`] is polled. Thus // do the `set_remote_description` call within the [`Future`]. Ok(async move { - let server_fingerprint = fingerprint_from_addr(&addr) + let remote_fingerprint = fingerprint_from_addr(&addr) .ok_or_else(|| Error::InvalidMultiaddr(addr.clone()))?; let conn = WebRTCConnection::connect( sock_addr, config.into_inner(), udp_mux, - &server_fingerprint, + &remote_fingerprint, ) .await?; @@ -192,8 +193,8 @@ impl Transport for WebRTCTransport { let peer_id = perform_noise_handshake_outbound( id_keys, Substream::new(data_channel.clone()), - client_fingerprint, - server_fingerprint, + our_fingerprint, + remote_fingerprint, ) .await?; @@ -495,8 +496,8 @@ fn multiaddr_to_socketaddr(addr: &Multiaddr) -> Option { async fn perform_noise_handshake_outbound( id_keys: identity::Keypair, poll_data_channel: T, - client_fingerprint: Fingerprint, - server_fingerprint: Fingerprint, + our_fingerprint: Fingerprint, + remote_fingerprint: Fingerprint, ) -> Result where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -504,8 +505,8 @@ where let dh_keys = Keypair::::new() .into_authentic(&id_keys) .unwrap(); - let noise = NoiseConfig::xx(dh_keys) - .with_prologue(noise_prologue(client_fingerprint, server_fingerprint)); + let noise = + NoiseConfig::xx(dh_keys).with_prologue(noise_prologue(our_fingerprint, remote_fingerprint)); let info = noise.protocol_info().next().unwrap(); let (peer_id, _noise_io) = noise .into_authenticated() @@ -524,13 +525,13 @@ async fn upgrade( ) -> Result<(PeerId, Connection), Error> { trace!("upgrading addr={} (ufrag={})", socket_addr, ufrag); - let server_fingerprint = config.fingerprint_of_first_certificate(); + let our_fingerprint = config.fingerprint_of_first_certificate(); let conn = WebRTCConnection::accept( socket_addr, config.into_inner(), udp_mux, - &server_fingerprint, + &our_fingerprint, &ufrag, ) .await?; @@ -543,13 +544,12 @@ async fn upgrade( socket_addr, ufrag ); - let client_fingerprint = conn.get_remote_fingerprint().await; - + let remote_fingerprint = conn.get_remote_fingerprint().await; let peer_id = perform_noise_handshake_inbound( id_keys, Substream::new(data_channel.clone()), - client_fingerprint, - server_fingerprint, + our_fingerprint, + remote_fingerprint, ) .await?; @@ -565,8 +565,8 @@ async fn upgrade( async fn perform_noise_handshake_inbound( id_keys: identity::Keypair, poll_data_channel: T, - client_fingerprint: Fingerprint, - server_fingerprint: Fingerprint, + our_fingerprint: Fingerprint, + remote_fingerprint: Fingerprint, ) -> Result where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -574,8 +574,8 @@ where let dh_keys = Keypair::::new() .into_authentic(&id_keys) .unwrap(); - let noise = NoiseConfig::xx(dh_keys) - .with_prologue(noise_prologue(client_fingerprint, server_fingerprint)); + let noise = + NoiseConfig::xx(dh_keys).with_prologue(noise_prologue(our_fingerprint, remote_fingerprint)); let info = noise.protocol_info().next().unwrap(); let (peer_id, _noise_io) = noise .into_authenticated() @@ -584,17 +584,18 @@ where Ok(peer_id) } -fn noise_prologue(client: Fingerprint, server: Fingerprint) -> Vec { - let server = server.to_multi_hash().to_bytes(); - let client = client.to_multi_hash().to_bytes(); - +fn noise_prologue(our_fingerprint: Fingerprint, remote_fingerprint: Fingerprint) -> Vec { + let (a, b): (Multihash, Multihash) = ( + our_fingerprint.to_multi_hash(), + remote_fingerprint.to_multi_hash(), + ); + let (a, b) = (a.to_bytes(), b.to_bytes()); + let (first, second) = if a < b { (a, b) } else { (b, a) }; const PREFIX: &[u8] = b"libp2p-webrtc-noise:"; - - let mut out = Vec::with_capacity(PREFIX.len() + server.len() + client.len()); + let mut out = Vec::with_capacity(PREFIX.len() + first.len() + second.len()); out.extend_from_slice(PREFIX); - out.extend_from_slice(&server); - out.extend_from_slice(&client); - + out.extend_from_slice(&first); + out.extend_from_slice(&second); out } @@ -623,7 +624,10 @@ mod tests { let prologue2 = noise_prologue(b, a); assert_eq!(hex::encode(&prologue1), "6c69627032702d7765627274632d6e6f6973653a122030fc9f469c207419dfdd0aab5f27a86c973c94e40548db9375cca2e915973b9912203e79af40d6059617a0d83b83a52ce73b0c1f37a72c6043ad2969e2351bdca870"); - assert_eq!(hex::encode(&prologue2), "6c69627032702d7765627274632d6e6f6973653a12203e79af40d6059617a0d83b83a52ce73b0c1f37a72c6043ad2969e2351bdca870122030fc9f469c207419dfdd0aab5f27a86c973c94e40548db9375cca2e915973b99"); + assert_eq!( + prologue1, prologue2, + "order of fingerprints does not matter" + ); } #[test] From d2da79332036c4767f049ebfb2d683a6fd6bfbca Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 17:02:28 +1100 Subject: [PATCH 26/39] Add initial test suite for substream state machine --- transports/webrtc/src/substream.rs | 346 +++++++++++++++-------------- 1 file changed, 184 insertions(+), 162 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index c78196e2fa9..c23278e925b 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -28,7 +28,7 @@ use webrtc::data::data_channel::DataChannel; use webrtc::data::data_channel::PollDataChannel; use std::{ - fmt, io, + io, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -53,6 +53,7 @@ const MAX_DATA_LEN: usize = MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD; pub struct Substream { io: Framed, prost_codec::Codec>, state: State, + read_buffer: Bytes, } impl Substream { @@ -66,15 +67,39 @@ impl Substream { Self { io: Framed::new(inner.compat(), prost_codec::Codec::new(MAX_MSG_LEN)), - state: State::Open { - read_buffer: Default::default(), - }, + state: State::Open, + read_buffer: Bytes::default(), } } fn stream_identifier(&self) -> u16 { self.io.get_ref().stream_identifier() } + + /// Gracefully closes the "read-half" of the substream. + pub fn poll_close_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.state.close_read_barrier()? { + Closing::Requested => { + ready!(self.io.poll_ready_unpin(cx))?; + + self.io.start_send_unpin(Message { + flag: Some(Flag::StopSending.into()), + message: None, + })?; + + continue; + } + Closing::MessageSent => { + ready!(self.io.poll_flush_unpin(cx))?; + + self.state.handle_outbound_flag(Flag::StopSending); + + return Poll::Ready(Ok(())); + } + } + } + } } impl AsyncRead for Substream { @@ -84,43 +109,29 @@ impl AsyncRead for Substream { buf: &mut [u8], ) -> Poll> { loop { - if let Some(read_buffer) = self.state.non_empty_read_buffer_mut() { - let n = std::cmp::min(read_buffer.len(), buf.len()); - let data = read_buffer.split_to(n); + self.state.read_barrier()?; + + if !self.read_buffer.is_empty() { + let n = std::cmp::min(self.read_buffer.len(), buf.len()); + let data = self.read_buffer.split_to(n); buf[0..n].copy_from_slice(&data[..]); return Poll::Ready(Ok(n)); } - let substream_id = self.stream_identifier(); - let Self { state, io } = &mut *self; - - let read_buffer = match state { - State::Open { read_buffer } | State::WriteClosed { read_buffer } => read_buffer, - State::ReadClosed { read_buffer, .. } - | State::ReadWriteClosed { read_buffer, .. } => { - assert!(read_buffer.is_empty()); - return Poll::Ready(Ok(0)); - } - State::ReadReset | State::ReadResetWriteClosed => { - return Poll::Ready(Err(io::Error::from(io::ErrorKind::ConnectionReset))); - } - State::Poisoned => unreachable!(), - }; - - match ready!(io_poll_next(io, cx))? { + match ready!(io_poll_next(&mut self.io, cx))? { Some((flag, message)) => { - assert!(read_buffer.is_empty()); - if let Some(message) = message { - *read_buffer = message.into(); + if let Some(flag) = flag { + self.state.handle_inbound_flag(flag); } - if let Some(flag) = flag { - self.state.handle_flag(flag, substream_id) - }; + debug_assert!(self.read_buffer.is_empty()); + if let Some(message) = message { + self.read_buffer = message.into(); + } } None => { - self.state.handle_flag(Flag::Fin, substream_id); + self.state.handle_inbound_flag(Flag::Fin); return Poll::Ready(Ok(0)); } } @@ -134,9 +145,7 @@ impl AsyncWrite for Substream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let substream_id = self.stream_identifier(); - // Handle flags iff read side closed. - while let State::ReadClosed { .. } | State::ReadReset = self.state { + while self.state.read_flags_in_async_write() { // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we will poll the // underlying I/O resource once more. Is that allowed? How about introducing a state IoReadClosed? match io_poll_next(&mut self.io, cx)? { @@ -144,20 +153,14 @@ impl AsyncWrite for Substream { // Read side is closed. Discard any incoming messages. drop(message); // But still handle flags, e.g. a `Flag::StopSending`. - self.state.handle_flag(flag, substream_id) + self.state.handle_inbound_flag(flag) } Poll::Ready(Some((None, message))) => drop(message), Poll::Ready(None) | Poll::Pending => break, } } - match self.state { - State::WriteClosed { .. } - | State::ReadWriteClosed { .. } - | State::ReadResetWriteClosed => return Poll::Ready(Ok(0)), - State::Open { .. } | State::ReadClosed { .. } | State::ReadReset => {} - State::Poisoned => todo!(), - } + self.state.write_barrier()?; ready!(self.io.poll_ready_unpin(cx))?; @@ -177,38 +180,27 @@ impl AsyncWrite for Substream { } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &self.state { - State::WriteClosed { .. } - | State::ReadWriteClosed { .. } - | State::ReadResetWriteClosed { .. } => {} - - State::Open { .. } | State::ReadClosed { .. } | State::ReadReset => { - ready!(self.io.poll_ready_unpin(cx))?; - Pin::new(&mut self.io).start_send(Message { - flag: Some(Flag::Fin.into()), - message: None, - })?; - - match std::mem::replace(&mut self.state, State::Poisoned) { - State::Open { read_buffer } => self.state = State::WriteClosed { read_buffer }, - State::ReadClosed { read_buffer } => { - self.state = State::ReadWriteClosed { read_buffer } - } - State::ReadReset => self.state = State::ReadResetWriteClosed, - State::WriteClosed { .. } - | State::ReadWriteClosed { .. } - | State::ReadResetWriteClosed - | State::Poisoned => { - unreachable!() - } + loop { + match self.state.close_write_barrier()? { + Closing::Requested => { + ready!(self.io.poll_ready_unpin(cx))?; + + self.io.start_send_unpin(Message { + flag: Some(Flag::Fin.into()), + message: None, + })?; + + continue; } - } + Closing::MessageSent => { + ready!(self.io.poll_flush_unpin(cx))?; - State::Poisoned => todo!(), - } + self.state.handle_outbound_flag(Flag::Fin); - // TODO: Is flush the correct thing here? We don't want the underlying layer to close both write and read. - self.io.poll_flush_unpin(cx).map_err(Into::into) + return Poll::Ready(Ok(())); + } + } + } } } @@ -234,106 +226,56 @@ fn io_poll_next( } enum State { - Open { read_buffer: Bytes }, - WriteClosed { read_buffer: Bytes }, - ReadClosed { read_buffer: Bytes }, - ReadWriteClosed { read_buffer: Bytes }, - ReadReset, - ReadResetWriteClosed, - Poisoned, + Open, + ReadClosed, + WriteClosed, + ClosingRead(Closing), + ClosingWrite(Closing), + BothClosed { reset: bool }, } -impl State { - fn handle_flag(&mut self, flag: Flag, substream_id: u16) { - let old_state = format!("{}", self); - match (std::mem::replace(self, State::Poisoned), flag) { - // StopSending - ( - State::Open { read_buffer } | State::WriteClosed { read_buffer }, - Flag::StopSending, - ) => { - *self = State::WriteClosed { read_buffer }; - } - - ( - State::ReadClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, - Flag::StopSending, - ) => { - *self = State::ReadWriteClosed { read_buffer }; - } - - (State::ReadReset | State::ReadResetWriteClosed, Flag::StopSending) => { - *self = State::ReadResetWriteClosed; - } - - // Fin - (State::Open { read_buffer } | State::ReadClosed { read_buffer }, Flag::Fin) => { - *self = State::ReadClosed { read_buffer }; - } - - ( - State::WriteClosed { read_buffer } | State::ReadWriteClosed { read_buffer }, - Flag::Fin, - ) => { - *self = State::ReadWriteClosed { read_buffer }; - } - - (State::ReadReset, Flag::Fin) => *self = State::ReadReset, - - (State::ReadResetWriteClosed, Flag::Fin) => *self = State::ReadResetWriteClosed, - - // Reset - (State::ReadClosed { .. } | State::ReadReset | State::Open { .. }, Flag::Reset) => { - *self = State::ReadReset - } +/// Represents the state of closing one half (either read or write) of the connection. +/// +/// Gracefully closing the read or write requires sending the `STOP_SENDING` or `FIN` flag respectively +/// and flushing the underlying connection. +enum Closing { + Requested, + MessageSent, +} - ( - State::ReadWriteClosed { .. } - | State::WriteClosed { .. } - | State::ReadResetWriteClosed, - Flag::Reset, - ) => *self = State::ReadResetWriteClosed, +impl State { + /// Performs a state transition for a flag contained in an inbound message. + fn handle_inbound_flag(&mut self, flag: Flag) {} + + /// Performs a state transition for a flag contained in an outbound message. + fn handle_outbound_flag(&mut self, flag: Flag) {} + + /// Whether we should read from the stream in the [`AsyncWrite`] implementation. + /// + /// This is necessary for read-closed streams because we would otherwise not read any more flags from + /// the socket. + fn read_flags_in_async_write(&self) -> bool { + false + } - (State::Poisoned, _) => unreachable!(), - } + /// Acts as a "barrier" for [`AsyncRead::poll_read`]. + fn read_barrier(&self) -> io::Result<()> { + Ok(()) + } - log::debug!( - "substream={}: got flag {:?}, moved from {} to {}", - substream_id, - flag, - old_state, - *self - ); + /// Acts as a "barrier" for [`AsyncWrite::poll_write`]. + fn write_barrier(&self) -> io::Result<()> { + Ok(()) } - /// Returns a reference to the underlying buffer if possible and the buffer is not empty. - fn non_empty_read_buffer_mut(&mut self) -> Option<&mut Bytes> { - match self { - State::Open { read_buffer } - | State::WriteClosed { read_buffer } - | State::ReadClosed { read_buffer } - | State::ReadWriteClosed { read_buffer } - if !read_buffer.is_empty() => - { - Some(read_buffer) - } - State::Poisoned => unreachable!(), - _ => None, - } + /// Acts as a "barrier" for [`AsyncWrite::poll_close`]. + fn close_write_barrier(&self) -> io::Result { + todo!() } -} -impl fmt::Display for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - State::Open { .. } => write!(f, "Open"), - State::WriteClosed { .. } => write!(f, "WriteClosed"), - State::ReadClosed { .. } => write!(f, "ReadClosed"), - State::ReadWriteClosed { .. } => write!(f, "ReadWriteClosed"), - State::ReadReset => write!(f, "ReadReset"), - State::ReadResetWriteClosed => write!(f, "ReadResetWriteClosed"), - State::Poisoned => write!(f, "Poisoned"), - } + /// Acts as a "barrier" for [`Substream::poll_close_read`]. + fn close_read_barrier(&self) -> io::Result { + todo!() } } @@ -343,8 +285,88 @@ mod tests { use asynchronous_codec::Encoder; use bytes::BytesMut; use prost::Message; + use std::io::ErrorKind; use unsigned_varint::codec::UviBytes; + #[test] + fn cannot_read_after_receiving_fin() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::Fin); + let error = open.read_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe) + } + + #[test] + fn cannot_read_after_sending_stop_sending() { + let mut open = State::Open; + + open.handle_outbound_flag(Flag::StopSending); + let error = open.read_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe) + } + + #[test] + fn cannot_write_after_receiving_stop_sending() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::StopSending); + let error = open.write_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe) + } + + #[test] + fn cannot_write_after_sending_fin() { + let mut open = State::Open; + + open.handle_outbound_flag(Flag::Fin); + let error = open.write_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe) + } + + #[test] + fn everything_broken_after_receiving_reset() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::Reset); + let error1 = open.read_barrier().unwrap_err(); + let error2 = open.write_barrier().unwrap_err(); + let error3 = open.close_write_barrier().unwrap_err(); + let error4 = open.close_read_barrier().unwrap_err(); + + assert_eq!(error1.kind(), ErrorKind::ConnectionReset); + assert_eq!(error2.kind(), ErrorKind::ConnectionReset); + assert_eq!(error3.kind(), ErrorKind::ConnectionReset); + assert_eq!(error4.kind(), ErrorKind::ConnectionReset); + } + + #[test] + fn should_read_flags_in_async_write_after_read_closed() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::Fin); + + assert!(open.read_flags_in_async_write()) + } + + #[test] + fn cannot_read_or_write_after_receiving_fin_and_stop_sending() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::Fin); + open.handle_inbound_flag(Flag::StopSending); + + let error1 = open.read_barrier().unwrap_err(); + let error2 = open.write_barrier().unwrap_err(); + + assert_eq!(error1.kind(), ErrorKind::BrokenPipe); + assert_eq!(error2.kind(), ErrorKind::BrokenPipe); + } + #[test] fn max_data_len() { // Largest possible message. From 1b520a9598620e4df32ef650d2009c612a3b8413 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 17:45:28 +1100 Subject: [PATCH 27/39] Precompute substream ID --- transports/webrtc/src/substream.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index c23278e925b..45ee66cd6ac 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -54,6 +54,7 @@ pub struct Substream { io: Framed, prost_codec::Codec>, state: State, read_buffer: Bytes, + substream_id: u16, } impl Substream { @@ -65,17 +66,17 @@ impl Substream { // https://github.com/webrtc-rs/webrtc/issues/273 is fixed. inner.set_read_buf_capacity(8192 * 10); + let io = Framed::new(inner.compat(), prost_codec::Codec::new(MAX_MSG_LEN)); + let substream_id = io.get_ref().stream_identifier(); + Self { - io: Framed::new(inner.compat(), prost_codec::Codec::new(MAX_MSG_LEN)), + io, state: State::Open, read_buffer: Bytes::default(), + substream_id, } } - fn stream_identifier(&self) -> u16 { - self.io.get_ref().stream_identifier() - } - /// Gracefully closes the "read-half" of the substream. pub fn poll_close_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { From e483974276cd4617e71b8ba56993dd214cb60551 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 17:51:49 +1100 Subject: [PATCH 28/39] Implement new state machine --- transports/webrtc/src/substream.rs | 404 ++++++++++++++++++++++++++--- 1 file changed, 372 insertions(+), 32 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 45ee66cd6ac..63ba6477ed2 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -81,23 +81,25 @@ impl Substream { pub fn poll_close_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { match self.state.close_read_barrier()? { - Closing::Requested => { + Some(Closing::Requested) => { ready!(self.io.poll_ready_unpin(cx))?; self.io.start_send_unpin(Message { flag: Some(Flag::StopSending.into()), message: None, })?; + self.state.close_read_message_sent(); continue; } - Closing::MessageSent => { + Some(Closing::MessageSent) => { ready!(self.io.poll_flush_unpin(cx))?; - self.state.handle_outbound_flag(Flag::StopSending); + self.state.read_closed(); return Poll::Ready(Ok(())); } + None => return Poll::Ready(Ok(())), } } } @@ -120,10 +122,11 @@ impl AsyncRead for Substream { return Poll::Ready(Ok(n)); } + let substream_id = self.substream_id; match ready!(io_poll_next(&mut self.io, cx))? { Some((flag, message)) => { if let Some(flag) = flag { - self.state.handle_inbound_flag(flag); + self.state.handle_inbound_flag(flag, substream_id); } debug_assert!(self.read_buffer.is_empty()); @@ -132,7 +135,7 @@ impl AsyncRead for Substream { } } None => { - self.state.handle_inbound_flag(Flag::Fin); + self.state.handle_inbound_flag(Flag::Fin, substream_id); return Poll::Ready(Ok(0)); } } @@ -149,12 +152,14 @@ impl AsyncWrite for Substream { while self.state.read_flags_in_async_write() { // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we will poll the // underlying I/O resource once more. Is that allowed? How about introducing a state IoReadClosed? + let substream_id = self.substream_id; + match io_poll_next(&mut self.io, cx)? { Poll::Ready(Some((Some(flag), message))) => { // Read side is closed. Discard any incoming messages. drop(message); // But still handle flags, e.g. a `Flag::StopSending`. - self.state.handle_inbound_flag(flag) + self.state.handle_inbound_flag(flag, substream_id) } Poll::Ready(Some((None, message))) => drop(message), Poll::Ready(None) | Poll::Pending => break, @@ -183,23 +188,25 @@ impl AsyncWrite for Substream { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { match self.state.close_write_barrier()? { - Closing::Requested => { + Some(Closing::Requested) => { ready!(self.io.poll_ready_unpin(cx))?; self.io.start_send_unpin(Message { flag: Some(Flag::Fin.into()), message: None, })?; + self.state.close_write_message_sent(); continue; } - Closing::MessageSent => { + Some(Closing::MessageSent) => { ready!(self.io.poll_flush_unpin(cx))?; - self.state.handle_outbound_flag(Flag::Fin); + self.state.write_closed(); return Poll::Ready(Ok(())); } + None => return Poll::Ready(Ok(())), } } } @@ -226,19 +233,31 @@ fn io_poll_next( } } +#[derive(Debug, Copy, Clone)] enum State { Open, ReadClosed, WriteClosed, - ClosingRead(Closing), - ClosingWrite(Closing), - BothClosed { reset: bool }, + ClosingRead { + /// Whether the write side of our channel was already closed. + write_closed: bool, + inner: Closing, + }, + ClosingWrite { + /// Whether the write side of our channel was already closed. + read_closed: bool, + inner: Closing, + }, + BothClosed { + reset: bool, + }, } /// Represents the state of closing one half (either read or write) of the connection. /// /// Gracefully closing the read or write requires sending the `STOP_SENDING` or `FIN` flag respectively /// and flushing the underlying connection. +#[derive(Debug, Copy, Clone)] enum Closing { Requested, MessageSent, @@ -246,37 +265,268 @@ enum Closing { impl State { /// Performs a state transition for a flag contained in an inbound message. - fn handle_inbound_flag(&mut self, flag: Flag) {} + fn handle_inbound_flag(&mut self, flag: Flag, substream_id: u16) { + let current = *self; + + match (current, flag) { + (Self::Open, Flag::Fin) => { + *self = Self::ReadClosed; + } + (Self::WriteClosed, Flag::Fin) => { + *self = Self::BothClosed { reset: false }; + } + (Self::Open, Flag::StopSending) => { + *self = Self::WriteClosed; + } + (Self::ReadClosed, Flag::StopSending) => { + *self = Self::BothClosed { reset: false }; + } + (_, Flag::Reset) => *self = Self::BothClosed { reset: true }, + _ => {} + } + + log::trace!("Transitioned from {current:?} to {self:?} on substream {substream_id}") + } + + fn write_closed(&mut self) { + match self { + State::ClosingWrite { + read_closed: true, + inner, + } => { + debug_assert!(matches!(inner, Closing::MessageSent)); + + *self = State::BothClosed { reset: false }; + } + State::ClosingWrite { + read_closed: false, + inner, + } => { + debug_assert!(matches!(inner, Closing::MessageSent)); + + *self = State::WriteClosed; + } + State::Open + | State::ReadClosed + | State::WriteClosed + | State::ClosingRead { .. } + | State::BothClosed { .. } => { + unreachable!("bad state machine impl") + } + } + } + + fn close_write_message_sent(&mut self) { + match self { + State::ClosingWrite { inner, read_closed } => { + debug_assert!(matches!(inner, Closing::Requested)); + + *self = State::ClosingWrite { + read_closed: *read_closed, + inner: Closing::MessageSent, + }; + } + State::Open + | State::ReadClosed + | State::WriteClosed + | State::ClosingRead { .. } + | State::BothClosed { .. } => { + unreachable!("bad state machine impl") + } + } + } + + fn read_closed(&mut self) { + match self { + State::ClosingRead { + write_closed: true, + inner, + } => { + debug_assert!(matches!(inner, Closing::MessageSent)); + + *self = State::BothClosed { reset: false }; + } + State::ClosingRead { + write_closed: false, + inner, + } => { + debug_assert!(matches!(inner, Closing::MessageSent)); + + *self = State::ReadClosed; + } + State::Open + | State::ReadClosed + | State::WriteClosed + | State::ClosingWrite { .. } + | State::BothClosed { .. } => { + unreachable!("bad state machine impl") + } + } + } - /// Performs a state transition for a flag contained in an outbound message. - fn handle_outbound_flag(&mut self, flag: Flag) {} + fn close_read_message_sent(&mut self) { + match self { + State::ClosingRead { + inner, + write_closed, + } => { + debug_assert!(matches!(inner, Closing::Requested)); + + *self = State::ClosingRead { + write_closed: *write_closed, + inner: Closing::MessageSent, + }; + } + State::Open + | State::ReadClosed + | State::WriteClosed + | State::ClosingWrite { .. } + | State::BothClosed { .. } => { + unreachable!("bad state machine impl") + } + } + } /// Whether we should read from the stream in the [`AsyncWrite`] implementation. /// /// This is necessary for read-closed streams because we would otherwise not read any more flags from /// the socket. fn read_flags_in_async_write(&self) -> bool { - false + matches!(self, Self::ReadClosed) } /// Acts as a "barrier" for [`AsyncRead::poll_read`]. fn read_barrier(&self) -> io::Result<()> { - Ok(()) + use State::*; + + let kind = match self { + Open + | WriteClosed + | ClosingWrite { + read_closed: false, .. + } => return Ok(()), + ClosingWrite { + read_closed: true, .. + } + | ReadClosed + | ClosingRead { .. } + | BothClosed { reset: false } => io::ErrorKind::BrokenPipe, + BothClosed { reset: true } => io::ErrorKind::ConnectionReset, + }; + + Err(kind.into()) } /// Acts as a "barrier" for [`AsyncWrite::poll_write`]. fn write_barrier(&self) -> io::Result<()> { - Ok(()) + use State::*; + + let kind = match self { + Open + | ReadClosed + | ClosingRead { + write_closed: false, + .. + } => return Ok(()), + ClosingRead { + write_closed: true, .. + } + | WriteClosed + | ClosingWrite { .. } + | BothClosed { reset: false } => io::ErrorKind::BrokenPipe, + BothClosed { reset: true } => io::ErrorKind::ConnectionReset, + }; + + Err(kind.into()) } /// Acts as a "barrier" for [`AsyncWrite::poll_close`]. - fn close_write_barrier(&self) -> io::Result { - todo!() + fn close_write_barrier(&mut self) -> io::Result> { + loop { + match &self { + State::WriteClosed => return Ok(None), + + State::ClosingWrite { inner, .. } => return Ok(Some(*inner)), + + State::Open => { + *self = Self::ClosingWrite { + read_closed: false, + inner: Closing::Requested, + }; + } + State::ReadClosed => { + *self = Self::ClosingWrite { + read_closed: true, + inner: Closing::Requested, + }; + } + + State::ClosingRead { + write_closed: true, .. + } + | State::BothClosed { reset: false } => { + return Err(io::ErrorKind::BrokenPipe.into()) + } + + State::ClosingRead { + write_closed: false, + .. + } => { + return Err(io::Error::new( + io::ErrorKind::Other, + "cannot close read half while closing write half", + )) + } + + State::BothClosed { reset: true } => { + return Err(io::ErrorKind::ConnectionReset.into()) + } + } + } } /// Acts as a "barrier" for [`Substream::poll_close_read`]. - fn close_read_barrier(&self) -> io::Result { - todo!() + fn close_read_barrier(&mut self) -> io::Result> { + loop { + match self { + State::ReadClosed => return Ok(None), + + State::ClosingRead { inner, .. } => return Ok(Some(*inner)), + + State::Open => { + *self = Self::ClosingRead { + write_closed: false, + inner: Closing::Requested, + }; + } + State::WriteClosed => { + *self = Self::ClosingRead { + write_closed: true, + inner: Closing::Requested, + }; + } + + State::ClosingWrite { + read_closed: true, .. + } + | State::BothClosed { reset: false } => { + return Err(io::ErrorKind::BrokenPipe.into()) + } + + State::ClosingWrite { + read_closed: false, .. + } => { + return Err(io::Error::new( + io::ErrorKind::Other, + "cannot close write half while closing read half", + )) + } + + State::BothClosed { reset: true } => { + return Err(io::ErrorKind::ConnectionReset.into()) + } + } + } } } @@ -293,17 +543,19 @@ mod tests { fn cannot_read_after_receiving_fin() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin); + open.handle_inbound_flag(Flag::Fin, 0); let error = open.read_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) } #[test] - fn cannot_read_after_sending_stop_sending() { + fn cannot_read_after_closing_read() { let mut open = State::Open; - open.handle_outbound_flag(Flag::StopSending); + open.close_read_barrier().unwrap(); + open.close_read_message_sent(); + open.read_closed(); let error = open.read_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) @@ -313,17 +565,19 @@ mod tests { fn cannot_write_after_receiving_stop_sending() { let mut open = State::Open; - open.handle_inbound_flag(Flag::StopSending); + open.handle_inbound_flag(Flag::StopSending, 0); let error = open.write_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) } #[test] - fn cannot_write_after_sending_fin() { + fn cannot_write_after_closing_write() { let mut open = State::Open; - open.handle_outbound_flag(Flag::Fin); + open.close_write_barrier().unwrap(); + open.close_write_message_sent(); + open.write_closed(); let error = open.write_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) @@ -333,7 +587,7 @@ mod tests { fn everything_broken_after_receiving_reset() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Reset); + open.handle_inbound_flag(Flag::Reset, 0); let error1 = open.read_barrier().unwrap_err(); let error2 = open.write_barrier().unwrap_err(); let error3 = open.close_write_barrier().unwrap_err(); @@ -349,7 +603,7 @@ mod tests { fn should_read_flags_in_async_write_after_read_closed() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin); + open.handle_inbound_flag(Flag::Fin, 0); assert!(open.read_flags_in_async_write()) } @@ -358,8 +612,8 @@ mod tests { fn cannot_read_or_write_after_receiving_fin_and_stop_sending() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin); - open.handle_inbound_flag(Flag::StopSending); + open.handle_inbound_flag(Flag::Fin, 0); + open.handle_inbound_flag(Flag::StopSending, 0); let error1 = open.read_barrier().unwrap_err(); let error2 = open.write_barrier().unwrap_err(); @@ -368,6 +622,92 @@ mod tests { assert_eq!(error2.kind(), ErrorKind::BrokenPipe); } + #[test] + fn can_read_after_closing_write() { + let mut open = State::Open; + + open.close_write_barrier().unwrap(); + open.close_write_message_sent(); + open.write_closed(); + + open.read_barrier().unwrap(); + } + + #[test] + fn can_write_after_closing_read() { + let mut open = State::Open; + + open.close_read_barrier().unwrap(); + open.close_read_message_sent(); + open.read_closed(); + + open.write_barrier().unwrap(); + } + + #[test] + fn cannot_write_after_starting_close() { + let mut open = State::Open; + + open.close_write_barrier().expect("to close in open"); + let error = open.write_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe); + } + + #[test] + fn cannot_read_after_starting_close() { + let mut open = State::Open; + + open.close_read_barrier().expect("to close in open"); + let error = open.read_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe); + } + + #[test] + fn can_read_in_open() { + let open = State::Open; + + let result = open.read_barrier(); + + result.unwrap(); + } + + #[test] + fn can_write_in_open() { + let open = State::Open; + + let result = open.write_barrier(); + + result.unwrap(); + } + + #[test] + fn write_close_barrier_returns_ok_when_closed() { + let mut open = State::Open; + + open.close_write_barrier().unwrap(); + open.close_write_message_sent(); + open.write_closed(); + + let maybe = open.close_write_barrier().unwrap(); + + assert!(maybe.is_none()) + } + + #[test] + fn read_close_barrier_returns_ok_when_closed() { + let mut open = State::Open; + + open.close_read_barrier().unwrap(); + open.close_read_message_sent(); + open.read_closed(); + + let maybe = open.close_read_barrier().unwrap(); + + assert!(maybe.is_none()) + } + #[test] fn max_data_len() { // Largest possible message. From 29d6f742b55e748100590c5956928992afdb704e Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 18:12:46 +1100 Subject: [PATCH 29/39] Reset flag clears buffer --- transports/webrtc/src/substream.rs | 59 +++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 63ba6477ed2..4ae63198290 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -122,20 +122,26 @@ impl AsyncRead for Substream { return Poll::Ready(Ok(n)); } - let substream_id = self.substream_id; - match ready!(io_poll_next(&mut self.io, cx))? { + let Self { + substream_id, + read_buffer, + io, + state, + } = &mut *self; + + match ready!(io_poll_next(io, cx))? { Some((flag, message)) => { if let Some(flag) = flag { - self.state.handle_inbound_flag(flag, substream_id); + state.handle_inbound_flag(flag, read_buffer, *substream_id); } - debug_assert!(self.read_buffer.is_empty()); + debug_assert!(read_buffer.is_empty()); if let Some(message) = message { - self.read_buffer = message.into(); + *read_buffer = message.into(); } } None => { - self.state.handle_inbound_flag(Flag::Fin, substream_id); + state.handle_inbound_flag(Flag::Fin, read_buffer, *substream_id); return Poll::Ready(Ok(0)); } } @@ -152,14 +158,20 @@ impl AsyncWrite for Substream { while self.state.read_flags_in_async_write() { // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we will poll the // underlying I/O resource once more. Is that allowed? How about introducing a state IoReadClosed? - let substream_id = self.substream_id; - match io_poll_next(&mut self.io, cx)? { + let Self { + substream_id, + read_buffer, + io, + state, + } = &mut *self; + + match io_poll_next(io, cx)? { Poll::Ready(Some((Some(flag), message))) => { // Read side is closed. Discard any incoming messages. drop(message); // But still handle flags, e.g. a `Flag::StopSending`. - self.state.handle_inbound_flag(flag, substream_id) + state.handle_inbound_flag(flag, read_buffer, *substream_id) } Poll::Ready(Some((None, message))) => drop(message), Poll::Ready(None) | Poll::Pending => break, @@ -265,7 +277,7 @@ enum Closing { impl State { /// Performs a state transition for a flag contained in an inbound message. - fn handle_inbound_flag(&mut self, flag: Flag, substream_id: u16) { + fn handle_inbound_flag(&mut self, flag: Flag, buffer: &mut Bytes, substream_id: u16) { let current = *self; match (current, flag) { @@ -281,7 +293,10 @@ impl State { (Self::ReadClosed, Flag::StopSending) => { *self = Self::BothClosed { reset: false }; } - (_, Flag::Reset) => *self = Self::BothClosed { reset: true }, + (_, Flag::Reset) => { + buffer.clear(); + *self = Self::BothClosed { reset: true }; + } _ => {} } @@ -543,7 +558,7 @@ mod tests { fn cannot_read_after_receiving_fin() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin, 0); + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default(), 0); let error = open.read_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) @@ -565,7 +580,7 @@ mod tests { fn cannot_write_after_receiving_stop_sending() { let mut open = State::Open; - open.handle_inbound_flag(Flag::StopSending, 0); + open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default(), 0); let error = open.write_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) @@ -587,7 +602,7 @@ mod tests { fn everything_broken_after_receiving_reset() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Reset, 0); + open.handle_inbound_flag(Flag::Reset, &mut Bytes::default(), 0); let error1 = open.read_barrier().unwrap_err(); let error2 = open.write_barrier().unwrap_err(); let error3 = open.close_write_barrier().unwrap_err(); @@ -603,7 +618,7 @@ mod tests { fn should_read_flags_in_async_write_after_read_closed() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin, 0); + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default(), 0); assert!(open.read_flags_in_async_write()) } @@ -612,8 +627,8 @@ mod tests { fn cannot_read_or_write_after_receiving_fin_and_stop_sending() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin, 0); - open.handle_inbound_flag(Flag::StopSending, 0); + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default(), 0); + open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default(), 0); let error1 = open.read_barrier().unwrap_err(); let error2 = open.write_barrier().unwrap_err(); @@ -708,6 +723,16 @@ mod tests { assert!(maybe.is_none()) } + #[test] + fn reset_flag_clears_buffer() { + let mut open = State::Open; + let mut buffer = Bytes::copy_from_slice(b"foobar"); + + open.handle_inbound_flag(Flag::Reset, &mut buffer, 0); + + assert!(buffer.is_empty()); + } + #[test] fn max_data_len() { // Largest possible message. From d829fda97d15c05a050223f4f2f7f11a52d5ef3c Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 18:14:02 +1100 Subject: [PATCH 30/39] Remove substream ID Logging these state transitions is no longer really worth it because we have changed the design to have many more functions which would all require logging now. --- transports/webrtc/src/substream.rs | 33 +++++++++++------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 4ae63198290..d0accc85255 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -54,7 +54,6 @@ pub struct Substream { io: Framed, prost_codec::Codec>, state: State, read_buffer: Bytes, - substream_id: u16, } impl Substream { @@ -66,14 +65,10 @@ impl Substream { // https://github.com/webrtc-rs/webrtc/issues/273 is fixed. inner.set_read_buf_capacity(8192 * 10); - let io = Framed::new(inner.compat(), prost_codec::Codec::new(MAX_MSG_LEN)); - let substream_id = io.get_ref().stream_identifier(); - Self { - io, + io: Framed::new(inner.compat(), prost_codec::Codec::new(MAX_MSG_LEN)), state: State::Open, read_buffer: Bytes::default(), - substream_id, } } @@ -123,7 +118,6 @@ impl AsyncRead for Substream { } let Self { - substream_id, read_buffer, io, state, @@ -132,7 +126,7 @@ impl AsyncRead for Substream { match ready!(io_poll_next(io, cx))? { Some((flag, message)) => { if let Some(flag) = flag { - state.handle_inbound_flag(flag, read_buffer, *substream_id); + state.handle_inbound_flag(flag, read_buffer); } debug_assert!(read_buffer.is_empty()); @@ -141,7 +135,7 @@ impl AsyncRead for Substream { } } None => { - state.handle_inbound_flag(Flag::Fin, read_buffer, *substream_id); + state.handle_inbound_flag(Flag::Fin, read_buffer); return Poll::Ready(Ok(0)); } } @@ -160,7 +154,6 @@ impl AsyncWrite for Substream { // underlying I/O resource once more. Is that allowed? How about introducing a state IoReadClosed? let Self { - substream_id, read_buffer, io, state, @@ -171,7 +164,7 @@ impl AsyncWrite for Substream { // Read side is closed. Discard any incoming messages. drop(message); // But still handle flags, e.g. a `Flag::StopSending`. - state.handle_inbound_flag(flag, read_buffer, *substream_id) + state.handle_inbound_flag(flag, read_buffer) } Poll::Ready(Some((None, message))) => drop(message), Poll::Ready(None) | Poll::Pending => break, @@ -277,7 +270,7 @@ enum Closing { impl State { /// Performs a state transition for a flag contained in an inbound message. - fn handle_inbound_flag(&mut self, flag: Flag, buffer: &mut Bytes, substream_id: u16) { + fn handle_inbound_flag(&mut self, flag: Flag, buffer: &mut Bytes) { let current = *self; match (current, flag) { @@ -299,8 +292,6 @@ impl State { } _ => {} } - - log::trace!("Transitioned from {current:?} to {self:?} on substream {substream_id}") } fn write_closed(&mut self) { @@ -558,7 +549,7 @@ mod tests { fn cannot_read_after_receiving_fin() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin, &mut Bytes::default(), 0); + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); let error = open.read_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) @@ -580,7 +571,7 @@ mod tests { fn cannot_write_after_receiving_stop_sending() { let mut open = State::Open; - open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default(), 0); + open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default()); let error = open.write_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) @@ -602,7 +593,7 @@ mod tests { fn everything_broken_after_receiving_reset() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Reset, &mut Bytes::default(), 0); + open.handle_inbound_flag(Flag::Reset, &mut Bytes::default()); let error1 = open.read_barrier().unwrap_err(); let error2 = open.write_barrier().unwrap_err(); let error3 = open.close_write_barrier().unwrap_err(); @@ -618,7 +609,7 @@ mod tests { fn should_read_flags_in_async_write_after_read_closed() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin, &mut Bytes::default(), 0); + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); assert!(open.read_flags_in_async_write()) } @@ -627,8 +618,8 @@ mod tests { fn cannot_read_or_write_after_receiving_fin_and_stop_sending() { let mut open = State::Open; - open.handle_inbound_flag(Flag::Fin, &mut Bytes::default(), 0); - open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default(), 0); + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); + open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default()); let error1 = open.read_barrier().unwrap_err(); let error2 = open.write_barrier().unwrap_err(); @@ -728,7 +719,7 @@ mod tests { let mut open = State::Open; let mut buffer = Bytes::copy_from_slice(b"foobar"); - open.handle_inbound_flag(Flag::Reset, &mut buffer, 0); + open.handle_inbound_flag(Flag::Reset, &mut buffer); assert!(buffer.is_empty()); } From 058a1531032b829b78003835a595423c4bb68a4f Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 18:17:41 +1100 Subject: [PATCH 31/39] Remove use of `map_err` --- transports/webrtc/src/connection.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/transports/webrtc/src/connection.rs b/transports/webrtc/src/connection.rs index cb44ccadee0..8c84e10fac3 100644 --- a/transports/webrtc/src/connection.rs +++ b/transports/webrtc/src/connection.rs @@ -24,7 +24,7 @@ use futures::{ oneshot::{self, Sender}, }, lock::Mutex as FutMutex, - {future::BoxFuture, prelude::*, ready}, + {future::BoxFuture, ready}, }; use futures_lite::StreamExt; use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent}; @@ -171,10 +171,7 @@ impl StreamMuxer for Connection { let peer_conn = peer_conn.lock().await; // Create a datachannel with label 'data' - let data_channel = peer_conn - .create_data_channel("data", None) - .map_err(Error::WebRTC) - .await?; + let data_channel = peer_conn.create_data_channel("data", None).await?; trace!("Opening outbound substream {}", data_channel.id()); @@ -211,7 +208,9 @@ impl StreamMuxer for Connection { let peer_conn = self.peer_conn.clone(); let fut = self.close_fut.get_or_insert(Box::pin(async move { let peer_conn = peer_conn.lock().await; - peer_conn.close().await.map_err(Error::WebRTC) + peer_conn.close().await?; + + Ok(()) })); match ready!(fut.as_mut().poll(cx)) { From e1df3c4889fea508e0a05daf287326134315c0e4 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 18:21:40 +1100 Subject: [PATCH 32/39] Replace error with `Poll::Pending` --- transports/webrtc/src/connection.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/transports/webrtc/src/connection.rs b/transports/webrtc/src/connection.rs index 8c84e10fac3..4d867ceeae4 100644 --- a/transports/webrtc/src/connection.rs +++ b/transports/webrtc/src/connection.rs @@ -149,9 +149,14 @@ impl StreamMuxer for Connection { Poll::Ready(Ok(Substream::new(detached))) } - None => Poll::Ready(Err(Error::Internal( - "incoming_data_channels_rx is closed (no messages left)".to_string(), - ))), + None => { + debug_assert!( + false, + "Sender-end of channel should be owned by `RTCPeerConnection`" + ); + + return Poll::Pending; // Return `Pending` without registering a waker: If the channel is closed, we don't need to be called anymore. + } } } From 8c8feaa2a5ff1bba59b8c1d91808011bfea6a6f7 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 18:26:18 +1100 Subject: [PATCH 33/39] Remove unnecessary dependency --- transports/webrtc/Cargo.toml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/transports/webrtc/Cargo.toml b/transports/webrtc/Cargo.toml index 04bc70a0a81..5872a676bb5 100644 --- a/transports/webrtc/Cargo.toml +++ b/transports/webrtc/Cargo.toml @@ -42,10 +42,9 @@ prost-build = "0.11" [dev-dependencies] anyhow = "1.0" env_logger = "0.9" -libp2p = { path = "../..", features = ["request-response", "webrtc"], default-features = false } -rand_core = "0.5" -quickcheck = "1" hex-literal = "0.3" +libp2p = { path = "../..", features = ["request-response", "webrtc"], default-features = false } multihash = { version = "0.16", default-features = false, features = ["sha3"] } +quickcheck = "1" +rand_core = "0.5" unsigned-varint = { version = "0.7", features = ["asynchronous_codec"] } -asynchronous-codec = { version = "0.6" } From e6c177c6b2f17d065cceea39adbb384528ecc9a9 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 18:26:27 +1100 Subject: [PATCH 34/39] Group imports --- transports/webrtc/src/substream.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index d0accc85255..08d8874b54e 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -34,8 +34,7 @@ use std::{ task::{Context, Poll}, }; -use crate::message_proto::message::Flag; -use crate::message_proto::Message; +use crate::message_proto::{message::Flag, Message}; /// Maximum length of a message, in bytes. const MAX_MSG_LEN: usize = 16384; // 16kiB From b2961a050e0c52ec2959f7759e8c32c2751125fd Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 18:26:32 +1100 Subject: [PATCH 35/39] Add spec wording to constant --- transports/webrtc/src/substream.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 08d8874b54e..82129d997b9 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -36,7 +36,8 @@ use std::{ use crate::message_proto::{message::Flag, Message}; -/// Maximum length of a message, in bytes. +/// As long as message interleaving is not supported, the sender SHOULD limit the maximum message size to 16 KB to avoid monopolization. +// Source: const MAX_MSG_LEN: usize = 16384; // 16kiB /// Length of varint, in bytes. const VARINT_LEN: usize = 2; From f827f62efcc6e814ad5902904a3c0b6f6618f359 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 18:48:07 +1100 Subject: [PATCH 36/39] Introduce dedicated `State` submodule --- transports/webrtc/src/substream.rs | 481 +--------------------- transports/webrtc/src/substream/state.rs | 488 +++++++++++++++++++++++ 2 files changed, 491 insertions(+), 478 deletions(-) create mode 100644 transports/webrtc/src/substream/state.rs diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 82129d997b9..18cc506fb24 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -27,6 +27,7 @@ use tokio_util::compat::TokioAsyncReadCompatExt; use webrtc::data::data_channel::DataChannel; use webrtc::data::data_channel::PollDataChannel; +use state::{Closing, State}; use std::{ io, pin::Pin, @@ -36,6 +37,8 @@ use std::{ use crate::message_proto::{message::Flag, Message}; +mod state; + /// As long as message interleaving is not supported, the sender SHOULD limit the maximum message size to 16 KB to avoid monopolization. // Source: const MAX_MSG_LEN: usize = 16384; // 16kiB @@ -238,492 +241,14 @@ fn io_poll_next( } } -#[derive(Debug, Copy, Clone)] -enum State { - Open, - ReadClosed, - WriteClosed, - ClosingRead { - /// Whether the write side of our channel was already closed. - write_closed: bool, - inner: Closing, - }, - ClosingWrite { - /// Whether the write side of our channel was already closed. - read_closed: bool, - inner: Closing, - }, - BothClosed { - reset: bool, - }, -} - -/// Represents the state of closing one half (either read or write) of the connection. -/// -/// Gracefully closing the read or write requires sending the `STOP_SENDING` or `FIN` flag respectively -/// and flushing the underlying connection. -#[derive(Debug, Copy, Clone)] -enum Closing { - Requested, - MessageSent, -} - -impl State { - /// Performs a state transition for a flag contained in an inbound message. - fn handle_inbound_flag(&mut self, flag: Flag, buffer: &mut Bytes) { - let current = *self; - - match (current, flag) { - (Self::Open, Flag::Fin) => { - *self = Self::ReadClosed; - } - (Self::WriteClosed, Flag::Fin) => { - *self = Self::BothClosed { reset: false }; - } - (Self::Open, Flag::StopSending) => { - *self = Self::WriteClosed; - } - (Self::ReadClosed, Flag::StopSending) => { - *self = Self::BothClosed { reset: false }; - } - (_, Flag::Reset) => { - buffer.clear(); - *self = Self::BothClosed { reset: true }; - } - _ => {} - } - } - - fn write_closed(&mut self) { - match self { - State::ClosingWrite { - read_closed: true, - inner, - } => { - debug_assert!(matches!(inner, Closing::MessageSent)); - - *self = State::BothClosed { reset: false }; - } - State::ClosingWrite { - read_closed: false, - inner, - } => { - debug_assert!(matches!(inner, Closing::MessageSent)); - - *self = State::WriteClosed; - } - State::Open - | State::ReadClosed - | State::WriteClosed - | State::ClosingRead { .. } - | State::BothClosed { .. } => { - unreachable!("bad state machine impl") - } - } - } - - fn close_write_message_sent(&mut self) { - match self { - State::ClosingWrite { inner, read_closed } => { - debug_assert!(matches!(inner, Closing::Requested)); - - *self = State::ClosingWrite { - read_closed: *read_closed, - inner: Closing::MessageSent, - }; - } - State::Open - | State::ReadClosed - | State::WriteClosed - | State::ClosingRead { .. } - | State::BothClosed { .. } => { - unreachable!("bad state machine impl") - } - } - } - - fn read_closed(&mut self) { - match self { - State::ClosingRead { - write_closed: true, - inner, - } => { - debug_assert!(matches!(inner, Closing::MessageSent)); - - *self = State::BothClosed { reset: false }; - } - State::ClosingRead { - write_closed: false, - inner, - } => { - debug_assert!(matches!(inner, Closing::MessageSent)); - - *self = State::ReadClosed; - } - State::Open - | State::ReadClosed - | State::WriteClosed - | State::ClosingWrite { .. } - | State::BothClosed { .. } => { - unreachable!("bad state machine impl") - } - } - } - - fn close_read_message_sent(&mut self) { - match self { - State::ClosingRead { - inner, - write_closed, - } => { - debug_assert!(matches!(inner, Closing::Requested)); - - *self = State::ClosingRead { - write_closed: *write_closed, - inner: Closing::MessageSent, - }; - } - State::Open - | State::ReadClosed - | State::WriteClosed - | State::ClosingWrite { .. } - | State::BothClosed { .. } => { - unreachable!("bad state machine impl") - } - } - } - - /// Whether we should read from the stream in the [`AsyncWrite`] implementation. - /// - /// This is necessary for read-closed streams because we would otherwise not read any more flags from - /// the socket. - fn read_flags_in_async_write(&self) -> bool { - matches!(self, Self::ReadClosed) - } - - /// Acts as a "barrier" for [`AsyncRead::poll_read`]. - fn read_barrier(&self) -> io::Result<()> { - use State::*; - - let kind = match self { - Open - | WriteClosed - | ClosingWrite { - read_closed: false, .. - } => return Ok(()), - ClosingWrite { - read_closed: true, .. - } - | ReadClosed - | ClosingRead { .. } - | BothClosed { reset: false } => io::ErrorKind::BrokenPipe, - BothClosed { reset: true } => io::ErrorKind::ConnectionReset, - }; - - Err(kind.into()) - } - - /// Acts as a "barrier" for [`AsyncWrite::poll_write`]. - fn write_barrier(&self) -> io::Result<()> { - use State::*; - - let kind = match self { - Open - | ReadClosed - | ClosingRead { - write_closed: false, - .. - } => return Ok(()), - ClosingRead { - write_closed: true, .. - } - | WriteClosed - | ClosingWrite { .. } - | BothClosed { reset: false } => io::ErrorKind::BrokenPipe, - BothClosed { reset: true } => io::ErrorKind::ConnectionReset, - }; - - Err(kind.into()) - } - - /// Acts as a "barrier" for [`AsyncWrite::poll_close`]. - fn close_write_barrier(&mut self) -> io::Result> { - loop { - match &self { - State::WriteClosed => return Ok(None), - - State::ClosingWrite { inner, .. } => return Ok(Some(*inner)), - - State::Open => { - *self = Self::ClosingWrite { - read_closed: false, - inner: Closing::Requested, - }; - } - State::ReadClosed => { - *self = Self::ClosingWrite { - read_closed: true, - inner: Closing::Requested, - }; - } - - State::ClosingRead { - write_closed: true, .. - } - | State::BothClosed { reset: false } => { - return Err(io::ErrorKind::BrokenPipe.into()) - } - - State::ClosingRead { - write_closed: false, - .. - } => { - return Err(io::Error::new( - io::ErrorKind::Other, - "cannot close read half while closing write half", - )) - } - - State::BothClosed { reset: true } => { - return Err(io::ErrorKind::ConnectionReset.into()) - } - } - } - } - - /// Acts as a "barrier" for [`Substream::poll_close_read`]. - fn close_read_barrier(&mut self) -> io::Result> { - loop { - match self { - State::ReadClosed => return Ok(None), - - State::ClosingRead { inner, .. } => return Ok(Some(*inner)), - - State::Open => { - *self = Self::ClosingRead { - write_closed: false, - inner: Closing::Requested, - }; - } - State::WriteClosed => { - *self = Self::ClosingRead { - write_closed: true, - inner: Closing::Requested, - }; - } - - State::ClosingWrite { - read_closed: true, .. - } - | State::BothClosed { reset: false } => { - return Err(io::ErrorKind::BrokenPipe.into()) - } - - State::ClosingWrite { - read_closed: false, .. - } => { - return Err(io::Error::new( - io::ErrorKind::Other, - "cannot close write half while closing read half", - )) - } - - State::BothClosed { reset: true } => { - return Err(io::ErrorKind::ConnectionReset.into()) - } - } - } - } -} - #[cfg(test)] mod tests { use super::*; use asynchronous_codec::Encoder; use bytes::BytesMut; use prost::Message; - use std::io::ErrorKind; use unsigned_varint::codec::UviBytes; - #[test] - fn cannot_read_after_receiving_fin() { - let mut open = State::Open; - - open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); - let error = open.read_barrier().unwrap_err(); - - assert_eq!(error.kind(), ErrorKind::BrokenPipe) - } - - #[test] - fn cannot_read_after_closing_read() { - let mut open = State::Open; - - open.close_read_barrier().unwrap(); - open.close_read_message_sent(); - open.read_closed(); - let error = open.read_barrier().unwrap_err(); - - assert_eq!(error.kind(), ErrorKind::BrokenPipe) - } - - #[test] - fn cannot_write_after_receiving_stop_sending() { - let mut open = State::Open; - - open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default()); - let error = open.write_barrier().unwrap_err(); - - assert_eq!(error.kind(), ErrorKind::BrokenPipe) - } - - #[test] - fn cannot_write_after_closing_write() { - let mut open = State::Open; - - open.close_write_barrier().unwrap(); - open.close_write_message_sent(); - open.write_closed(); - let error = open.write_barrier().unwrap_err(); - - assert_eq!(error.kind(), ErrorKind::BrokenPipe) - } - - #[test] - fn everything_broken_after_receiving_reset() { - let mut open = State::Open; - - open.handle_inbound_flag(Flag::Reset, &mut Bytes::default()); - let error1 = open.read_barrier().unwrap_err(); - let error2 = open.write_barrier().unwrap_err(); - let error3 = open.close_write_barrier().unwrap_err(); - let error4 = open.close_read_barrier().unwrap_err(); - - assert_eq!(error1.kind(), ErrorKind::ConnectionReset); - assert_eq!(error2.kind(), ErrorKind::ConnectionReset); - assert_eq!(error3.kind(), ErrorKind::ConnectionReset); - assert_eq!(error4.kind(), ErrorKind::ConnectionReset); - } - - #[test] - fn should_read_flags_in_async_write_after_read_closed() { - let mut open = State::Open; - - open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); - - assert!(open.read_flags_in_async_write()) - } - - #[test] - fn cannot_read_or_write_after_receiving_fin_and_stop_sending() { - let mut open = State::Open; - - open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); - open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default()); - - let error1 = open.read_barrier().unwrap_err(); - let error2 = open.write_barrier().unwrap_err(); - - assert_eq!(error1.kind(), ErrorKind::BrokenPipe); - assert_eq!(error2.kind(), ErrorKind::BrokenPipe); - } - - #[test] - fn can_read_after_closing_write() { - let mut open = State::Open; - - open.close_write_barrier().unwrap(); - open.close_write_message_sent(); - open.write_closed(); - - open.read_barrier().unwrap(); - } - - #[test] - fn can_write_after_closing_read() { - let mut open = State::Open; - - open.close_read_barrier().unwrap(); - open.close_read_message_sent(); - open.read_closed(); - - open.write_barrier().unwrap(); - } - - #[test] - fn cannot_write_after_starting_close() { - let mut open = State::Open; - - open.close_write_barrier().expect("to close in open"); - let error = open.write_barrier().unwrap_err(); - - assert_eq!(error.kind(), ErrorKind::BrokenPipe); - } - - #[test] - fn cannot_read_after_starting_close() { - let mut open = State::Open; - - open.close_read_barrier().expect("to close in open"); - let error = open.read_barrier().unwrap_err(); - - assert_eq!(error.kind(), ErrorKind::BrokenPipe); - } - - #[test] - fn can_read_in_open() { - let open = State::Open; - - let result = open.read_barrier(); - - result.unwrap(); - } - - #[test] - fn can_write_in_open() { - let open = State::Open; - - let result = open.write_barrier(); - - result.unwrap(); - } - - #[test] - fn write_close_barrier_returns_ok_when_closed() { - let mut open = State::Open; - - open.close_write_barrier().unwrap(); - open.close_write_message_sent(); - open.write_closed(); - - let maybe = open.close_write_barrier().unwrap(); - - assert!(maybe.is_none()) - } - - #[test] - fn read_close_barrier_returns_ok_when_closed() { - let mut open = State::Open; - - open.close_read_barrier().unwrap(); - open.close_read_message_sent(); - open.read_closed(); - - let maybe = open.close_read_barrier().unwrap(); - - assert!(maybe.is_none()) - } - - #[test] - fn reset_flag_clears_buffer() { - let mut open = State::Open; - let mut buffer = Bytes::copy_from_slice(b"foobar"); - - open.handle_inbound_flag(Flag::Reset, &mut buffer); - - assert!(buffer.is_empty()); - } - #[test] fn max_data_len() { // Largest possible message. diff --git a/transports/webrtc/src/substream/state.rs b/transports/webrtc/src/substream/state.rs new file mode 100644 index 00000000000..f300951c634 --- /dev/null +++ b/transports/webrtc/src/substream/state.rs @@ -0,0 +1,488 @@ +use crate::message_proto::message::Flag; +use bytes::Bytes; +use std::io; + +#[derive(Debug, Copy, Clone)] +pub enum State { + Open, + ReadClosed, + WriteClosed, + ClosingRead { + /// Whether the write side of our channel was already closed. + write_closed: bool, + inner: Closing, + }, + ClosingWrite { + /// Whether the write side of our channel was already closed. + read_closed: bool, + inner: Closing, + }, + BothClosed { + reset: bool, + }, +} + +/// Represents the state of closing one half (either read or write) of the connection. +/// +/// Gracefully closing the read or write requires sending the `STOP_SENDING` or `FIN` flag respectively +/// and flushing the underlying connection. +#[derive(Debug, Copy, Clone)] +pub enum Closing { + Requested, + MessageSent, +} + +impl State { + /// Performs a state transition for a flag contained in an inbound message. + pub(crate) fn handle_inbound_flag(&mut self, flag: Flag, buffer: &mut Bytes) { + let current = *self; + + match (current, flag) { + (Self::Open, Flag::Fin) => { + *self = Self::ReadClosed; + } + (Self::WriteClosed, Flag::Fin) => { + *self = Self::BothClosed { reset: false }; + } + (Self::Open, Flag::StopSending) => { + *self = Self::WriteClosed; + } + (Self::ReadClosed, Flag::StopSending) => { + *self = Self::BothClosed { reset: false }; + } + (_, Flag::Reset) => { + buffer.clear(); + *self = Self::BothClosed { reset: true }; + } + _ => {} + } + } + + pub(crate) fn write_closed(&mut self) { + match self { + State::ClosingWrite { + read_closed: true, + inner, + } => { + debug_assert!(matches!(inner, Closing::MessageSent)); + + *self = State::BothClosed { reset: false }; + } + State::ClosingWrite { + read_closed: false, + inner, + } => { + debug_assert!(matches!(inner, Closing::MessageSent)); + + *self = State::WriteClosed; + } + State::Open + | State::ReadClosed + | State::WriteClosed + | State::ClosingRead { .. } + | State::BothClosed { .. } => { + unreachable!("bad state machine impl") + } + } + } + + pub(crate) fn close_write_message_sent(&mut self) { + match self { + State::ClosingWrite { inner, read_closed } => { + debug_assert!(matches!(inner, Closing::Requested)); + + *self = State::ClosingWrite { + read_closed: *read_closed, + inner: Closing::MessageSent, + }; + } + State::Open + | State::ReadClosed + | State::WriteClosed + | State::ClosingRead { .. } + | State::BothClosed { .. } => { + unreachable!("bad state machine impl") + } + } + } + + pub(crate) fn read_closed(&mut self) { + match self { + State::ClosingRead { + write_closed: true, + inner, + } => { + debug_assert!(matches!(inner, Closing::MessageSent)); + + *self = State::BothClosed { reset: false }; + } + State::ClosingRead { + write_closed: false, + inner, + } => { + debug_assert!(matches!(inner, Closing::MessageSent)); + + *self = State::ReadClosed; + } + State::Open + | State::ReadClosed + | State::WriteClosed + | State::ClosingWrite { .. } + | State::BothClosed { .. } => { + unreachable!("bad state machine impl") + } + } + } + + pub(crate) fn close_read_message_sent(&mut self) { + match self { + State::ClosingRead { + inner, + write_closed, + } => { + debug_assert!(matches!(inner, Closing::Requested)); + + *self = State::ClosingRead { + write_closed: *write_closed, + inner: Closing::MessageSent, + }; + } + State::Open + | State::ReadClosed + | State::WriteClosed + | State::ClosingWrite { .. } + | State::BothClosed { .. } => { + unreachable!("bad state machine impl") + } + } + } + + /// Whether we should read from the stream in the [`AsyncWrite`] implementation. + /// + /// This is necessary for read-closed streams because we would otherwise not read any more flags from + /// the socket. + pub(crate) fn read_flags_in_async_write(&self) -> bool { + matches!(self, Self::ReadClosed) + } + + /// Acts as a "barrier" for [`AsyncRead::poll_read`]. + pub(crate) fn read_barrier(&self) -> io::Result<()> { + use crate::substream::State::{Open, ReadClosed, WriteClosed}; + use State::*; + + let kind = match self { + Open + | WriteClosed + | ClosingWrite { + read_closed: false, .. + } => return Ok(()), + ClosingWrite { + read_closed: true, .. + } + | ReadClosed + | ClosingRead { .. } + | BothClosed { reset: false } => io::ErrorKind::BrokenPipe, + BothClosed { reset: true } => io::ErrorKind::ConnectionReset, + }; + + Err(kind.into()) + } + + /// Acts as a "barrier" for [`AsyncWrite::poll_write`]. + pub(crate) fn write_barrier(&self) -> io::Result<()> { + use crate::substream::State::{Open, ReadClosed, WriteClosed}; + use State::*; + + let kind = match self { + Open + | ReadClosed + | ClosingRead { + write_closed: false, + .. + } => return Ok(()), + ClosingRead { + write_closed: true, .. + } + | WriteClosed + | ClosingWrite { .. } + | BothClosed { reset: false } => io::ErrorKind::BrokenPipe, + BothClosed { reset: true } => io::ErrorKind::ConnectionReset, + }; + + Err(kind.into()) + } + + /// Acts as a "barrier" for [`AsyncWrite::poll_close`]. + pub(crate) fn close_write_barrier(&mut self) -> io::Result> { + loop { + match &self { + State::WriteClosed => return Ok(None), + + State::ClosingWrite { inner, .. } => return Ok(Some(*inner)), + + State::Open => { + *self = Self::ClosingWrite { + read_closed: false, + inner: Closing::Requested, + }; + } + State::ReadClosed => { + *self = Self::ClosingWrite { + read_closed: true, + inner: Closing::Requested, + }; + } + + State::ClosingRead { + write_closed: true, .. + } + | State::BothClosed { reset: false } => { + return Err(io::ErrorKind::BrokenPipe.into()) + } + + State::ClosingRead { + write_closed: false, + .. + } => { + return Err(io::Error::new( + io::ErrorKind::Other, + "cannot close read half while closing write half", + )) + } + + State::BothClosed { reset: true } => { + return Err(io::ErrorKind::ConnectionReset.into()) + } + } + } + } + + /// Acts as a "barrier" for [`Substream::poll_close_read`]. + pub fn close_read_barrier(&mut self) -> io::Result> { + loop { + match self { + State::ReadClosed => return Ok(None), + + State::ClosingRead { inner, .. } => return Ok(Some(*inner)), + + State::Open => { + *self = Self::ClosingRead { + write_closed: false, + inner: Closing::Requested, + }; + } + State::WriteClosed => { + *self = Self::ClosingRead { + write_closed: true, + inner: Closing::Requested, + }; + } + + State::ClosingWrite { + read_closed: true, .. + } + | State::BothClosed { reset: false } => { + return Err(io::ErrorKind::BrokenPipe.into()) + } + + State::ClosingWrite { + read_closed: false, .. + } => { + return Err(io::Error::new( + io::ErrorKind::Other, + "cannot close write half while closing read half", + )) + } + + State::BothClosed { reset: true } => { + return Err(io::ErrorKind::ConnectionReset.into()) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::ErrorKind; + + #[test] + fn cannot_read_after_receiving_fin() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); + let error = open.read_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe) + } + + #[test] + fn cannot_read_after_closing_read() { + let mut open = State::Open; + + open.close_read_barrier().unwrap(); + open.close_read_message_sent(); + open.read_closed(); + let error = open.read_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe) + } + + #[test] + fn cannot_write_after_receiving_stop_sending() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default()); + let error = open.write_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe) + } + + #[test] + fn cannot_write_after_closing_write() { + let mut open = State::Open; + + open.close_write_barrier().unwrap(); + open.close_write_message_sent(); + open.write_closed(); + let error = open.write_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe) + } + + #[test] + fn everything_broken_after_receiving_reset() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::Reset, &mut Bytes::default()); + let error1 = open.read_barrier().unwrap_err(); + let error2 = open.write_barrier().unwrap_err(); + let error3 = open.close_write_barrier().unwrap_err(); + let error4 = open.close_read_barrier().unwrap_err(); + + assert_eq!(error1.kind(), ErrorKind::ConnectionReset); + assert_eq!(error2.kind(), ErrorKind::ConnectionReset); + assert_eq!(error3.kind(), ErrorKind::ConnectionReset); + assert_eq!(error4.kind(), ErrorKind::ConnectionReset); + } + + #[test] + fn should_read_flags_in_async_write_after_read_closed() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); + + assert!(open.read_flags_in_async_write()) + } + + #[test] + fn cannot_read_or_write_after_receiving_fin_and_stop_sending() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::Fin, &mut Bytes::default()); + open.handle_inbound_flag(Flag::StopSending, &mut Bytes::default()); + + let error1 = open.read_barrier().unwrap_err(); + let error2 = open.write_barrier().unwrap_err(); + + assert_eq!(error1.kind(), ErrorKind::BrokenPipe); + assert_eq!(error2.kind(), ErrorKind::BrokenPipe); + } + + #[test] + fn can_read_after_closing_write() { + let mut open = State::Open; + + open.close_write_barrier().unwrap(); + open.close_write_message_sent(); + open.write_closed(); + + open.read_barrier().unwrap(); + } + + #[test] + fn can_write_after_closing_read() { + let mut open = State::Open; + + open.close_read_barrier().unwrap(); + open.close_read_message_sent(); + open.read_closed(); + + open.write_barrier().unwrap(); + } + + #[test] + fn cannot_write_after_starting_close() { + let mut open = State::Open; + + open.close_write_barrier().expect("to close in open"); + let error = open.write_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe); + } + + #[test] + fn cannot_read_after_starting_close() { + let mut open = State::Open; + + open.close_read_barrier().expect("to close in open"); + let error = open.read_barrier().unwrap_err(); + + assert_eq!(error.kind(), ErrorKind::BrokenPipe); + } + + #[test] + fn can_read_in_open() { + let open = State::Open; + + let result = open.read_barrier(); + + result.unwrap(); + } + + #[test] + fn can_write_in_open() { + let open = State::Open; + + let result = open.write_barrier(); + + result.unwrap(); + } + + #[test] + fn write_close_barrier_returns_ok_when_closed() { + let mut open = State::Open; + + open.close_write_barrier().unwrap(); + open.close_write_message_sent(); + open.write_closed(); + + let maybe = open.close_write_barrier().unwrap(); + + assert!(maybe.is_none()) + } + + #[test] + fn read_close_barrier_returns_ok_when_closed() { + let mut open = State::Open; + + open.close_read_barrier().unwrap(); + open.close_read_message_sent(); + open.read_closed(); + + let maybe = open.close_read_barrier().unwrap(); + + assert!(maybe.is_none()) + } + + #[test] + fn reset_flag_clears_buffer() { + let mut open = State::Open; + let mut buffer = Bytes::copy_from_slice(b"foobar"); + + open.handle_inbound_flag(Flag::Reset, &mut buffer); + + assert!(buffer.is_empty()); + } +} From 16433db8f935b754a837dd22c9c33d802b7b4a93 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 19:17:13 +1100 Subject: [PATCH 37/39] Send reset flag for dropped substreams --- transports/webrtc/src/connection.rs | 47 ++++++-- transports/webrtc/src/substream.rs | 39 +++++-- .../webrtc/src/substream/drop_listener.rs | 108 ++++++++++++++++++ transports/webrtc/src/substream/framed_dc.rs | 21 ++++ transports/webrtc/src/upgrade.rs | 5 +- 5 files changed, 198 insertions(+), 22 deletions(-) create mode 100644 transports/webrtc/src/substream/drop_listener.rs create mode 100644 transports/webrtc/src/substream/framed_dc.rs diff --git a/transports/webrtc/src/connection.rs b/transports/webrtc/src/connection.rs index 4d867ceeae4..feca0b10e53 100644 --- a/transports/webrtc/src/connection.rs +++ b/transports/webrtc/src/connection.rs @@ -18,29 +18,30 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use futures::stream::FuturesUnordered; use futures::{ channel::{ mpsc, oneshot::{self, Sender}, }, lock::Mutex as FutMutex, + StreamExt, {future::BoxFuture, ready}, }; -use futures_lite::StreamExt; use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent}; use log::{debug, error, trace}; use webrtc::data::data_channel::DataChannel as DetachedDataChannel; use webrtc::data_channel::RTCDataChannel; use webrtc::peer_connection::RTCPeerConnection; +use std::task::Waker; use std::{ pin::Pin, sync::Arc, task::{Context, Poll}, }; -use crate::error::Error; -use crate::substream::Substream; +use crate::{error::Error, substream, substream::Substream}; const MAX_DATA_CHANNELS_IN_FLIGHT: usize = 10; @@ -59,6 +60,9 @@ pub struct Connection { /// Future, which, once polled, will result in closing the entire connection. close_fut: Option>>, + + drop_listeners: FuturesUnordered, + no_drop_listeners_waker: Option, } impl Unpin for Connection {} @@ -75,6 +79,8 @@ impl Connection { incoming_data_channels_rx: data_channel_rx, outbound_fut: None, close_fut: None, + drop_listeners: FuturesUnordered::default(), + no_drop_listeners_waker: None, } } @@ -143,11 +149,17 @@ impl StreamMuxer for Connection { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - match ready!(self.incoming_data_channels_rx.poll_next(cx)) { + match ready!(self.incoming_data_channels_rx.poll_next_unpin(cx)) { Some(detached) => { trace!("Incoming substream {}", detached.stream_identifier()); - Poll::Ready(Ok(Substream::new(detached))) + let (substream, drop_listener) = Substream::new(detached); + self.drop_listeners.push(drop_listener); + if let Some(waker) = self.no_drop_listeners_waker.take() { + waker.wake() + } + + Poll::Ready(Ok(substream)) } None => { debug_assert!( @@ -161,10 +173,21 @@ impl StreamMuxer for Connection { } fn poll( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, ) -> Poll> { - Poll::Pending + loop { + match ready!(self.drop_listeners.poll_next_unpin(cx)) { + Some(Ok(())) => {} + Some(Err(e)) => { + log::debug!("a DropListener failed: {e}") + } + None => { + self.no_drop_listeners_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + } + } } fn poll_outbound( @@ -198,7 +221,13 @@ impl StreamMuxer for Connection { match ready!(fut.as_mut().poll(cx)) { Ok(detached) => { self.outbound_fut = None; - Poll::Ready(Ok(Substream::new(detached))) + let (substream, drop_listener) = Substream::new(detached); + self.drop_listeners.push(drop_listener); + if let Some(waker) = self.no_drop_listeners_waker.take() { + waker.wake() + } + + Poll::Ready(Ok(substream)) } Err(e) => { self.outbound_fut = None; diff --git a/transports/webrtc/src/substream.rs b/transports/webrtc/src/substream.rs index 18cc506fb24..b4edd2b8aac 100644 --- a/transports/webrtc/src/substream.rs +++ b/transports/webrtc/src/substream.rs @@ -20,14 +20,13 @@ use asynchronous_codec::Framed; use bytes::Bytes; +use futures::channel::oneshot; use futures::prelude::*; use futures::ready; use tokio_util::compat::Compat; -use tokio_util::compat::TokioAsyncReadCompatExt; use webrtc::data::data_channel::DataChannel; use webrtc::data::data_channel::PollDataChannel; -use state::{Closing, State}; use std::{ io, pin::Pin, @@ -36,7 +35,12 @@ use std::{ }; use crate::message_proto::{message::Flag, Message}; +use crate::substream::drop_listener::GracefullyClosed; +use crate::substream::framed_dc::FramedDC; +use crate::substream::state::{Closing, State}; +mod drop_listener; +mod framed_dc; mod state; /// As long as message interleaving is not supported, the sender SHOULD limit the maximum message size to 16 KB to avoid monopolization. @@ -49,30 +53,34 @@ const PROTO_OVERHEAD: usize = 5; /// Maximum length of data, in bytes. const MAX_DATA_LEN: usize = MAX_MSG_LEN - VARINT_LEN - PROTO_OVERHEAD; +pub use drop_listener::DropListener; + /// A substream on top of a WebRTC data channel. /// /// To be a proper libp2p substream, we need to implement [`AsyncRead`] and [`AsyncWrite`] as well /// as support a half-closed state which we do by framing messages in a protobuf envelope. pub struct Substream { - io: Framed, prost_codec::Codec>, + io: FramedDC, state: State, read_buffer: Bytes, + /// Dropping this will close the oneshot and notify the receiver by emitting `Canceled`. + drop_notifier: Option>, } impl Substream { /// Constructs a new `Substream`. - pub(crate) fn new(data_channel: Arc) -> Self { - let mut inner = PollDataChannel::new(data_channel); - - // TODO: default buffer size is too small to fit some messages. Possibly remove once - // https://github.com/webrtc-rs/webrtc/issues/273 is fixed. - inner.set_read_buf_capacity(8192 * 10); + pub(crate) fn new(data_channel: Arc) -> (Self, DropListener) { + let (sender, receiver) = oneshot::channel(); - Self { - io: Framed::new(inner.compat(), prost_codec::Codec::new(MAX_MSG_LEN)), + let substream = Self { + io: framed_dc::new(data_channel.clone()), state: State::Open, read_buffer: Bytes::default(), - } + drop_notifier: Some(sender), + }; + let listener = DropListener::new(framed_dc::new(data_channel), receiver); + + (substream, listener) } /// Gracefully closes the "read-half" of the substream. @@ -124,6 +132,7 @@ impl AsyncRead for Substream { read_buffer, io, state, + .. } = &mut *self; match ready!(io_poll_next(io, cx))? { @@ -160,6 +169,7 @@ impl AsyncWrite for Substream { read_buffer, io, state, + .. } = &mut *self; match io_poll_next(io, cx)? { @@ -211,6 +221,11 @@ impl AsyncWrite for Substream { ready!(self.io.poll_flush_unpin(cx))?; self.state.write_closed(); + let _ = self + .drop_notifier + .take() + .expect("to not close twice") + .send(GracefullyClosed {}); return Poll::Ready(Ok(())); } diff --git a/transports/webrtc/src/substream/drop_listener.rs b/transports/webrtc/src/substream/drop_listener.rs new file mode 100644 index 00000000000..26fdc206641 --- /dev/null +++ b/transports/webrtc/src/substream/drop_listener.rs @@ -0,0 +1,108 @@ +use crate::message_proto::{message::Flag, Message}; +use crate::substream::framed_dc::FramedDC; +use futures::channel::oneshot; +use futures::channel::oneshot::Canceled; +use futures::{FutureExt, SinkExt}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[must_use] +pub struct DropListener { + state: State, +} + +impl DropListener { + pub fn new(stream: FramedDC, receiver: oneshot::Receiver) -> Self { + let substream_id = stream.get_ref().stream_identifier(); + + Self { + state: State::Idle { + stream, + receiver, + substream_id, + }, + } + } +} + +enum State { + /// The [`DropListener`] is idle and waiting to be activated. + Idle { + stream: FramedDC, + receiver: oneshot::Receiver, + substream_id: u16, + }, + /// The stream got dropped and we are sending a reset flag. + SendingReset { + stream: FramedDC, + }, + Flushing { + stream: FramedDC, + }, + /// Bad state transition. + Poisoned, +} + +impl Future for DropListener { + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let state = &mut self.get_mut().state; + + loop { + match std::mem::replace(state, State::Poisoned) { + State::Idle { + stream, + substream_id, + mut receiver, + } => match receiver.poll_unpin(cx) { + Poll::Ready(Ok(GracefullyClosed {})) => { + return Poll::Ready(Ok(())); + } + Poll::Ready(Err(Canceled)) => { + log::info!("Substream {substream_id} dropped without graceful close, sending Reset"); + *state = State::SendingReset { stream }; + continue; + } + Poll::Pending => { + *state = State::Idle { + stream, + substream_id, + receiver, + }; + return Poll::Pending; + } + }, + State::SendingReset { mut stream } => match stream.poll_ready_unpin(cx)? { + Poll::Ready(()) => { + stream.start_send_unpin(Message { + flag: Some(Flag::Reset.into()), + message: None, + })?; + *state = State::Flushing { stream }; + continue; + } + Poll::Pending => { + *state = State::SendingReset { stream }; + return Poll::Pending; + } + }, + State::Flushing { mut stream } => match stream.poll_flush_unpin(cx)? { + Poll::Ready(()) => return Poll::Ready(Ok(())), + Poll::Pending => { + *state = State::Flushing { stream }; + return Poll::Pending; + } + }, + State::Poisoned => { + unreachable!() + } + } + } + } +} + +/// Indicates that our substream got gracefully closed. +pub struct GracefullyClosed {} diff --git a/transports/webrtc/src/substream/framed_dc.rs b/transports/webrtc/src/substream/framed_dc.rs new file mode 100644 index 00000000000..b28b0d467a0 --- /dev/null +++ b/transports/webrtc/src/substream/framed_dc.rs @@ -0,0 +1,21 @@ +use crate::message_proto::Message; +use crate::substream::MAX_MSG_LEN; +use asynchronous_codec::Framed; +use std::sync::Arc; +use tokio_util::compat::Compat; +use tokio_util::compat::TokioAsyncReadCompatExt; +use webrtc::data::data_channel::{DataChannel, PollDataChannel}; + +pub type FramedDC = Framed, prost_codec::Codec>; + +pub fn new( + data_channel: Arc, +) -> Framed, prost_codec::Codec> { + let mut inner = PollDataChannel::new(data_channel.clone()); + + // TODO: default buffer size is too small to fit some messages. Possibly remove once + // https://github.com/webrtc-rs/webrtc/issues/273 is fixed. + inner.set_read_buf_capacity(8192 * 10); + + Framed::new(inner.compat(), prost_codec::Codec::new(MAX_MSG_LEN)) +} diff --git a/transports/webrtc/src/upgrade.rs b/transports/webrtc/src/upgrade.rs index 1664d5a9db5..2a6999a3d54 100644 --- a/transports/webrtc/src/upgrade.rs +++ b/transports/webrtc/src/upgrade.rs @@ -223,5 +223,8 @@ async fn create_substream_for_noise_handshake( } }; - Ok(Substream::new(channel)) + let (substream, drop_listener) = Substream::new(channel); + drop(drop_listener); // Don't care about cancelled substreams during initial handshake. + + Ok(substream) } From 834ec644496d8bd42631c006d6a4849ac29799f5 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 19:23:38 +1100 Subject: [PATCH 38/39] Fix clippy warnings --- transports/webrtc/src/connection.rs | 2 +- transports/webrtc/src/substream/framed_dc.rs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/transports/webrtc/src/connection.rs b/transports/webrtc/src/connection.rs index feca0b10e53..d4e765bd85a 100644 --- a/transports/webrtc/src/connection.rs +++ b/transports/webrtc/src/connection.rs @@ -167,7 +167,7 @@ impl StreamMuxer for Connection { "Sender-end of channel should be owned by `RTCPeerConnection`" ); - return Poll::Pending; // Return `Pending` without registering a waker: If the channel is closed, we don't need to be called anymore. + Poll::Pending // Return `Pending` without registering a waker: If the channel is closed, we don't need to be called anymore. } } } diff --git a/transports/webrtc/src/substream/framed_dc.rs b/transports/webrtc/src/substream/framed_dc.rs index b28b0d467a0..63f985c1f4a 100644 --- a/transports/webrtc/src/substream/framed_dc.rs +++ b/transports/webrtc/src/substream/framed_dc.rs @@ -8,10 +8,8 @@ use webrtc::data::data_channel::{DataChannel, PollDataChannel}; pub type FramedDC = Framed, prost_codec::Codec>; -pub fn new( - data_channel: Arc, -) -> Framed, prost_codec::Codec> { - let mut inner = PollDataChannel::new(data_channel.clone()); +pub fn new(data_channel: Arc) -> FramedDC { + let mut inner = PollDataChannel::new(data_channel); // TODO: default buffer size is too small to fit some messages. Possibly remove once // https://github.com/webrtc-rs/webrtc/issues/273 is fixed. From 7e0e46d728c8ca496649ce7bc807e2e25dead7c6 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 11 Oct 2022 19:26:01 +1100 Subject: [PATCH 39/39] Fix docs --- transports/webrtc/src/fingerprint.rs | 2 +- transports/webrtc/src/substream/state.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/transports/webrtc/src/fingerprint.rs b/transports/webrtc/src/fingerprint.rs index 0a307edeb36..b9c63fe1474 100644 --- a/transports/webrtc/src/fingerprint.rs +++ b/transports/webrtc/src/fingerprint.rs @@ -89,7 +89,7 @@ impl Fingerprint { } /// Returns the algorithm used (e.g. "sha-256"). - /// See https://datatracker.ietf.org/doc/html/rfc8122#section-5 + /// See pub fn algorithm(&self) -> String { SHA256.to_owned() } diff --git a/transports/webrtc/src/substream/state.rs b/transports/webrtc/src/substream/state.rs index f300951c634..5baf219b364 100644 --- a/transports/webrtc/src/substream/state.rs +++ b/transports/webrtc/src/substream/state.rs @@ -157,7 +157,7 @@ impl State { } } - /// Whether we should read from the stream in the [`AsyncWrite`] implementation. + /// Whether we should read from the stream in the [`futures::AsyncWrite`] implementation. /// /// This is necessary for read-closed streams because we would otherwise not read any more flags from /// the socket. @@ -165,7 +165,7 @@ impl State { matches!(self, Self::ReadClosed) } - /// Acts as a "barrier" for [`AsyncRead::poll_read`]. + /// Acts as a "barrier" for [`futures::AsyncRead::poll_read`]. pub(crate) fn read_barrier(&self) -> io::Result<()> { use crate::substream::State::{Open, ReadClosed, WriteClosed}; use State::*; @@ -188,7 +188,7 @@ impl State { Err(kind.into()) } - /// Acts as a "barrier" for [`AsyncWrite::poll_write`]. + /// Acts as a "barrier" for [`futures::AsyncWrite::poll_write`]. pub(crate) fn write_barrier(&self) -> io::Result<()> { use crate::substream::State::{Open, ReadClosed, WriteClosed}; use State::*; @@ -212,7 +212,7 @@ impl State { Err(kind.into()) } - /// Acts as a "barrier" for [`AsyncWrite::poll_close`]. + /// Acts as a "barrier" for [`futures::AsyncWrite::poll_close`]. pub(crate) fn close_write_barrier(&mut self) -> io::Result> { loop { match &self { @@ -257,7 +257,7 @@ impl State { } } - /// Acts as a "barrier" for [`Substream::poll_close_read`]. + /// Acts as a "barrier" for [`Substream::poll_close_read`](super::Substream::poll_close_read). pub fn close_read_barrier(&mut self) -> io::Result> { loop { match self {