Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net/tcp: Add poll_accept, poll_shared_accept and poll_shutdown #533

Merged
merged 4 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 59 additions & 13 deletions glommio/src/net/tcp_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@ use crate::{
yolo_accept,
},
reactor::Reactor,
sys::Source,
GlommioError,
};
use futures_lite::{
future::poll_fn,
io::{AsyncBufRead, AsyncRead, AsyncWrite},
ready,
stream::{self, Stream},
};
use nix::sys::socket::SockaddrStorage;
use pin_project_lite::pin_project;
use socket2::{Domain, Protocol, Socket, Type};
use std::{
cell::RefCell,
io,
net::{self, Shutdown, SocketAddr, ToSocketAddrs},
os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
Expand Down Expand Up @@ -75,6 +78,7 @@ type Result<T> = crate::Result<T, ()>;
pub struct TcpListener {
reactor: Weak<Reactor>,
listener: net::TcpListener,
current_source: RefCell<Option<Source>>,
}

impl FromRawFd for TcpListener {
Expand All @@ -86,6 +90,7 @@ impl FromRawFd for TcpListener {
TcpListener {
reactor: Rc::downgrade(&crate::executor().reactor()),
listener,
current_source: Default::default(),
}
}
}
Expand Down Expand Up @@ -132,6 +137,7 @@ impl TcpListener {
Ok(TcpListener {
reactor: Rc::downgrade(&crate::executor().reactor()),
listener,
current_source: Default::default(),
})
}

Expand Down Expand Up @@ -164,19 +170,38 @@ impl TcpListener {
/// [`TcpStream`]: struct.TcpStream.html
/// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html
pub async fn shared_accept(&self) -> Result<AcceptedTcpStream> {
let reactor = self.reactor.upgrade().unwrap();
let raw_fd = self.listener.as_raw_fd();
if let Some(r) = yolo_accept(raw_fd) {
match r {
Ok(fd) => {
return Ok(AcceptedTcpStream { fd });
poll_fn(|cx| self.poll_shared_accept(cx)).await
}

/// Poll version of [`shared_accept`].
///
/// [`shared_accept`]: TcpListener::shared_accept
pub fn poll_shared_accept(&self, cx: &mut Context<'_>) -> Poll<Result<AcceptedTcpStream>> {
let mut poll_source = |source: Source| match source.poll_collect_rw(cx) {
Poll::Ready(Ok(fd)) => Poll::Ready(Ok(AcceptedTcpStream { fd: fd as RawFd })),
Poll::Ready(Err(err)) => Poll::Ready(Err(GlommioError::IoError(err))),
Poll::Pending => {
*self.current_source.borrow_mut() = Some(source);
Poll::Pending
}
};
match self.current_source.take() {
Some(source) => poll_source(source),
None => {
let reactor = self.reactor.upgrade().unwrap();
let raw_fd = self.listener.as_raw_fd();
match yolo_accept(raw_fd) {
Some(r) => match r {
Ok(fd) => Poll::Ready(Ok(AcceptedTcpStream { fd })),
Err(err) => Poll::Ready(Err(GlommioError::IoError(err))),
},
None => {
let source = reactor.accept(self.listener.as_raw_fd());
poll_source(source)
}
}
Err(err) => return Err(GlommioError::IoError(err)),
}
}
let source = reactor.accept(self.listener.as_raw_fd());
let fd = source.collect_rw().await?;
Ok(AcceptedTcpStream { fd: fd as RawFd })
}

/// Accepts a new incoming TCP connection in this executor
Expand Down Expand Up @@ -208,6 +233,19 @@ impl TcpListener {
Ok(a.bind_to_executor())
}
Comment on lines 233 to 234
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should use poll_fn here like elsewhere for the sake of consistency


/// Poll version of [`accept`].
///
/// [`accept`]: TcpListener::accept
pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<TcpStream>> {
match ready!(self.poll_shared_accept(cx)) {
Ok(a) => {
let a = a.bind_to_executor();
Poll::Ready(Ok(a))
}
Err(err) => Poll::Ready(Err(err)),
}
}

/// Creates a stream of incoming connections
///
/// # Examples
Expand Down Expand Up @@ -554,9 +592,17 @@ impl<B: RxBuf> TcpStream<B> {

/// Shuts down the read, write, or both halves of this connection.
pub async fn shutdown(&self, how: Shutdown) -> Result<()> {
poll_fn(|cx| self.stream.poll_shutdown(cx, how))
.await
.map_err(Into::into)
poll_fn(|cx| self.poll_shutdown(cx, how)).await
}

/// Polling version of [`shutdown`].
///
/// [`shutdown`]: TcpStream::shutdown
pub fn poll_shutdown(&self, cx: &mut Context<'_>, how: Shutdown) -> Poll<Result<()>> {
match ready!(self.stream.poll_shutdown(cx, how)) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(err.into())),
}
Comment on lines +602 to +605
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
match ready!(self.stream.poll_shutdown(cx, how)) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(err.into())),
}
ready!(self.stream.poll_shutdown(cx, how)).map_err(Into::into)

}

/// Sets the value of the `TCP_NODELAY` option on this socket.
Expand Down
19 changes: 10 additions & 9 deletions glommio/src/sys/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::{
path::PathBuf,
pin::Pin,
rc::Rc,
task::{Poll, Waker},
task::{Context, Poll, Waker},
time::Duration,
};

Expand Down Expand Up @@ -303,15 +303,16 @@ impl Source {
}

pub(crate) async fn collect_rw(&self) -> io::Result<usize> {
future::poll_fn(|cx| {
if let Some(result) = self.result() {
return Poll::Ready(result);
}
future::poll_fn(|cx| self.poll_collect_rw(cx)).await
}

self.add_waiter_many(cx.waker().clone());
Poll::Pending
})
.await
pub(crate) fn poll_collect_rw(&self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
if let Some(result) = self.result() {
return Poll::Ready(result);
}

self.add_waiter_many(cx.waker().clone());
Poll::Pending
}
}

Expand Down
Loading